feat(nodes): added gradient mask node

This commit is contained in:
dunkeroni 2024-02-20 21:13:19 -05:00 committed by Kent Keirsey
parent 970d45f691
commit b2b7aed030
4 changed files with 80 additions and 11 deletions

View File

@ -199,6 +199,7 @@ class DenoiseMaskField(BaseModel):
mask_name: str = Field(description="The name of the mask image") mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
gradient: Optional[bool] = Field(default=False, description="Used for gradient inpainting")
class LatentsField(BaseModel): class LatentsField(BaseModel):

View File

@ -23,7 +23,7 @@ from diffusers.models.attention_processor import (
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image from PIL import Image, ImageFilter
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@ -128,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4, ui_order=4,
) )
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L": if mask_image.mode != "L":
mask_image = mask_image.convert("L") mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
@ -169,6 +169,62 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
return DenoiseMaskOutput.build( return DenoiseMaskOutput.build(
mask_name=mask_name, mask_name=mask_name,
masked_latents_name=masked_latents_name, masked_latents_name=masked_latents_name,
gradient=False,
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
) )
@ -606,9 +662,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_inpaint_mask( def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
if self.denoise_mask is None: if self.denoise_mask is None:
return None, None return None, None, False
mask = context.tensors.load(self.denoise_mask.mask_name) mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
@ -617,7 +673,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else: else:
masked_latents = None masked_latents = None
return 1 - mask, masked_latents return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -644,7 +700,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if seed is None: if seed is None:
seed = 0 seed = 0
mask, masked_latents = self.prep_inpaint_mask(context, latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate. # below. Investigate whether this is appropriate.
@ -732,6 +788,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed, seed=seed,
mask=mask, mask=mask,
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,

View File

@ -299,9 +299,13 @@ class DenoiseMaskOutput(BaseInvocationOutput):
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@classmethod @classmethod
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": def build(
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: Optional[bool] = False
) -> "DenoiseMaskOutput":
return cls( return cls(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), denoise_mask=DenoiseMaskField(
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
),
) )

View File

@ -86,6 +86,7 @@ class AddsMaskGuidance:
mask_latents: torch.FloatTensor mask_latents: torch.FloatTensor
scheduler: SchedulerMixin scheduler: SchedulerMixin
noise: torch.Tensor noise: torch.Tensor
gradient_mask: bool
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput: def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data. output_class = step_output.__class__ # We'll create a new one with masked data.
@ -121,7 +122,12 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) if self.gradient_mask:
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
else:
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
return masked_input return masked_input
@ -335,6 +341,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0: if init_timestep.shape[0] == 0:
@ -375,7 +382,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._unet_forward, mask, masked_latents self._unet_forward, mask, masked_latents
) )
else: else:
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise)) additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try: try:
latents, attention_map_saver = self.generate_latents_from_embeddings( latents, attention_map_saver = self.generate_latents_from_embeddings(
@ -392,7 +399,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.invokeai_diffuser.model_forward_callback = self._unet_forward self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part # restore unmasked part
if mask is not None: if mask is not None and not gradient_mask:
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)) latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
return latents, attention_map_saver return latents, attention_map_saver