mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): added gradient mask node
This commit is contained in:
parent
970d45f691
commit
b2b7aed030
@ -199,6 +199,7 @@ class DenoiseMaskField(BaseModel):
|
|||||||
|
|
||||||
mask_name: str = Field(description="The name of the mask image")
|
mask_name: str = Field(description="The name of the mask image")
|
||||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||||
|
gradient: Optional[bool] = Field(default=False, description="Used for gradient inpainting")
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
|
@ -23,7 +23,7 @@ from diffusers.models.attention_processor import (
|
|||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from PIL import Image
|
from PIL import Image, ImageFilter
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
ui_order=4,
|
ui_order=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||||
if mask_image.mode != "L":
|
if mask_image.mode != "L":
|
||||||
mask_image = mask_image.convert("L")
|
mask_image = mask_image.convert("L")
|
||||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||||
@ -169,6 +169,62 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
return DenoiseMaskOutput.build(
|
return DenoiseMaskOutput.build(
|
||||||
mask_name=mask_name,
|
mask_name=mask_name,
|
||||||
masked_latents_name=masked_latents_name,
|
masked_latents_name=masked_latents_name,
|
||||||
|
gradient=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"create_gradient_mask",
|
||||||
|
title="Create Gradient Mask",
|
||||||
|
tags=["mask", "denoise"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
class CreateGradientMaskInvocation(BaseInvocation):
|
||||||
|
"""Creates mask for denoising model run."""
|
||||||
|
|
||||||
|
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||||
|
edge_radius: int = InputField(
|
||||||
|
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
|
||||||
|
)
|
||||||
|
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
|
||||||
|
minimum_denoise: float = InputField(
|
||||||
|
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||||
|
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||||
|
if self.coherence_mode == "Box Blur":
|
||||||
|
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||||
|
else: # Gaussian Blur OR Staged
|
||||||
|
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||||
|
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||||
|
|
||||||
|
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||||
|
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||||
|
|
||||||
|
# redistribute blur so that the edges are 0 and blur out to 1
|
||||||
|
blur_tensor = (blur_tensor - 0.5) * 2
|
||||||
|
|
||||||
|
threshold = 1 - self.minimum_denoise
|
||||||
|
|
||||||
|
if self.coherence_mode == "Staged":
|
||||||
|
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||||
|
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||||
|
else:
|
||||||
|
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||||
|
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||||
|
|
||||||
|
# multiply original mask to force actually masked regions to 0
|
||||||
|
blur_tensor = mask_tensor * blur_tensor
|
||||||
|
|
||||||
|
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||||
|
|
||||||
|
return DenoiseMaskOutput.build(
|
||||||
|
mask_name=mask_name,
|
||||||
|
masked_latents_name=None,
|
||||||
|
gradient=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -606,9 +662,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def prep_inpaint_mask(
|
def prep_inpaint_mask(
|
||||||
self, context: InvocationContext, latents: torch.Tensor
|
self, context: InvocationContext, latents: torch.Tensor
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
|
||||||
if self.denoise_mask is None:
|
if self.denoise_mask is None:
|
||||||
return None, None
|
return None, None, False
|
||||||
|
|
||||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
@ -617,7 +673,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
masked_latents = None
|
masked_latents = None
|
||||||
|
|
||||||
return 1 - mask, masked_latents
|
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
@ -644,7 +700,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if seed is None:
|
if seed is None:
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
@ -732,6 +788,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
masked_latents=masked_latents,
|
masked_latents=masked_latents,
|
||||||
|
gradient_mask=gradient_mask,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=controlnet_data,
|
control_data=controlnet_data,
|
||||||
|
@ -299,9 +299,13 @@ class DenoiseMaskOutput(BaseInvocationOutput):
|
|||||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
|
def build(
|
||||||
|
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: Optional[bool] = False
|
||||||
|
) -> "DenoiseMaskOutput":
|
||||||
return cls(
|
return cls(
|
||||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
|
denoise_mask=DenoiseMaskField(
|
||||||
|
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,6 +86,7 @@ class AddsMaskGuidance:
|
|||||||
mask_latents: torch.FloatTensor
|
mask_latents: torch.FloatTensor
|
||||||
scheduler: SchedulerMixin
|
scheduler: SchedulerMixin
|
||||||
noise: torch.Tensor
|
noise: torch.Tensor
|
||||||
|
gradient_mask: bool
|
||||||
|
|
||||||
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
|
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
|
||||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||||
@ -121,7 +122,12 @@ class AddsMaskGuidance:
|
|||||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
if self.gradient_mask:
|
||||||
|
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
|
||||||
|
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
|
||||||
|
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||||
|
else:
|
||||||
|
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||||
return masked_input
|
return masked_input
|
||||||
|
|
||||||
|
|
||||||
@ -335,6 +341,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
masked_latents: Optional[torch.Tensor] = None,
|
masked_latents: Optional[torch.Tensor] = None,
|
||||||
|
gradient_mask: Optional[bool] = False,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
if init_timestep.shape[0] == 0:
|
if init_timestep.shape[0] == 0:
|
||||||
@ -375,7 +382,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self._unet_forward, mask, masked_latents
|
self._unet_forward, mask, masked_latents
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
||||||
@ -392,7 +399,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||||
|
|
||||||
# restore unmasked part
|
# restore unmasked part
|
||||||
if mask is not None:
|
if mask is not None and not gradient_mask:
|
||||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||||
|
|
||||||
return latents, attention_map_saver
|
return latents, attention_map_saver
|
||||||
|
Loading…
Reference in New Issue
Block a user