diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 40d403c03d..7f2d2783f2 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -199,6 +199,7 @@ class DenoiseMaskField(BaseModel):
 
     mask_name: str = Field(description="The name of the mask image")
     masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
+    gradient: Optional[bool] = Field(default=False, description="Used for gradient inpainting")
 
 
 class LatentsField(BaseModel):
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index bfe7255b62..97d3c705d4 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -23,7 +23,7 @@ from diffusers.models.attention_processor import (
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from PIL import Image
+from PIL import Image, ImageFilter
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 
@@ -128,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         ui_order=4,
     )
 
-    def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
+    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
         if mask_image.mode != "L":
             mask_image = mask_image.convert("L")
         mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
@@ -169,6 +169,62 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         return DenoiseMaskOutput.build(
             mask_name=mask_name,
             masked_latents_name=masked_latents_name,
+            gradient=False,
+        )
+
+
+@invocation(
+    "create_gradient_mask",
+    title="Create Gradient Mask",
+    tags=["mask", "denoise"],
+    category="latents",
+    version="1.0.0",
+)
+class CreateGradientMaskInvocation(BaseInvocation):
+    """Creates a gradient mask for the denoising model run."""
+
+    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
+    edge_radius: int = InputField(
+        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
+    )
+    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
+    minimum_denoise: float = InputField(
+        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
+        mask_image = context.images.get_pil(self.mask.image_name, mode="L")
+        if self.coherence_mode == "Box Blur":
+            blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
+        else:  # Gaussian Blur OR Staged
+            # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
+            blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
+
+        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
+
+        # redistribute blur so that the edges are 0 and blur out to 1
+        blur_tensor = (blur_tensor - 0.5) * 2
+
+        threshold = 1 - self.minimum_denoise
+
+        if self.coherence_mode == "Staged":
+            # wherever the blur_tensor is masked to any degree, convert it to threshold
+            blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
+        else:
+            # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
+            blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
+
+        # multiply original mask to force actually masked regions to 0
+        blur_tensor = mask_tensor * blur_tensor
+
+        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
+
+        return DenoiseMaskOutput.build(
+            mask_name=mask_name,
+            masked_latents_name=None,
+            gradient=True,
         )
 
 
@@ -606,9 +662,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def prep_inpaint_mask(
         self, context: InvocationContext, latents: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
         if self.denoise_mask is None:
-            return None, None
+            return None, None, False
 
         mask = context.tensors.load(self.denoise_mask.mask_name)
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
@@ -617,7 +673,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         else:
             masked_latents = None
 
-        return 1 - mask, masked_latents
+        return 1 - mask, masked_latents, self.denoise_mask.gradient
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -644,7 +700,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             if seed is None:
                 seed = 0
 
-            mask, masked_latents = self.prep_inpaint_mask(context, latents)
+            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
 
             # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
             # below. Investigate whether this is appropriate.
@@ -732,6 +788,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 seed=seed,
                 mask=mask,
                 masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
                 num_inference_steps=num_inference_steps,
                 conditioning_data=conditioning_data,
                 control_data=controlnet_data,
diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py
index 4342213482..c761bb0895 100644
--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -299,9 +299,13 @@ class DenoiseMaskOutput(BaseInvocationOutput):
     denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
 
     @classmethod
-    def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
+    def build(
+        cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: Optional[bool] = False
+    ) -> "DenoiseMaskOutput":
         return cls(
-            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
+            denoise_mask=DenoiseMaskField(
+                mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
+            ),
         )
 
 
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index a85e3762dc..fd3ecde47b 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -86,6 +86,7 @@ class AddsMaskGuidance:
     mask_latents: torch.FloatTensor
     scheduler: SchedulerMixin
     noise: torch.Tensor
+    gradient_mask: bool
 
     def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
         output_class = step_output.__class__  # We'll create a new one with masked data.
@@ -121,7 +122,12 @@ class AddsMaskGuidance:
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        if self.gradient_mask:
+            threshold = t.item() / self.scheduler.config.num_train_timesteps
+            # prep_inpaint_mask inverted the mask: here 1 = denoise, 0 = preserve original
+            mask_bool = mask > threshold
+            masked_input = torch.where(mask_bool, latents, mask_latents)
+        else:
+            masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
         return masked_input
 
 
@@ -335,6 +341,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
+        gradient_mask: Optional[bool] = False,
         seed: Optional[int] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if init_timestep.shape[0] == 0:
@@ -375,7 +382,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     self._unet_forward, mask, masked_latents
                 )
             else:
-                additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
+                additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
 
         try:
             latents, attention_map_saver = self.generate_latents_from_embeddings(
@@ -392,7 +399,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
         # restore unmasked part
-        if mask is not None:
+        if mask is not None and not gradient_mask:
             latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
 
         return latents, attention_map_saver
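
For reference, the mask math in CreateGradientMaskInvocation can be reproduced outside the graph. This is a minimal illustrative sketch, not part of the diff: the file path and minimum_denoise value are hypothetical, and to_tensor is a simplified stand-in for image_resized_to_grid_as_tensor (no resizing, just L-mode pixels scaled to [0, 1]). It follows the default Gaussian Blur branch; the Staged branch instead flattens every value below 1 to the threshold.

    import numpy as np
    import torch
    from PIL import Image, ImageFilter

    # Hypothetical stand-ins for the node's inputs.
    edge_radius = 16
    minimum_denoise = 0.25
    mask_image = Image.open("mask.png").convert("L")  # black = inpaint, white = preserve

    # GaussianBlur takes a standard deviation, so edge_radius / 2 approximates
    # a falloff roughly edge_radius pixels wide.
    blur_mask = mask_image.filter(ImageFilter.GaussianBlur(edge_radius / 2))

    def to_tensor(im: Image.Image) -> torch.Tensor:
        # uint8 [0, 255] -> float32 [0.0, 1.0]
        return torch.from_numpy(np.asarray(im, dtype=np.float32) / 255.0)

    mask_tensor = to_tensor(mask_image)
    blur_tensor = to_tensor(blur_mask)

    # Rescale so the blurred edge runs from <= 0 (inside the mask) up to 1 (untouched).
    blur_tensor = (blur_tensor - 0.5) * 2

    # Drop the partially blurred band to the threshold so the coherence region
    # denoises by at least minimum_denoise.
    threshold = 1 - minimum_denoise
    blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)

    # Multiply by the original binary mask so fully masked pixels are exactly 0.
    gradient_mask = mask_tensor * blur_tensor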
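
Likewise, a self-contained sketch of the per-step gating that AddsMaskGuidance applies when gradient_mask is set; this is my reading of the change, with num_train_timesteps assumed to be the usual 1000 and gradient_mask_step a hypothetical name:

    import torch

    def gradient_mask_step(
        latents: torch.Tensor,       # latents after the current denoising step
        mask_latents: torch.Tensor,  # original latents re-noised to timestep t
        mask: torch.Tensor,          # gradient mask; 1 = denoise fully, 0 = preserve
        t: torch.Tensor,             # current scheduler timestep
        num_train_timesteps: int = 1000,  # scheduler.config.num_train_timesteps
    ) -> torch.Tensor:
        # Early in sampling t is near num_train_timesteps, so the threshold is
        # near 1 and only fully masked pixels keep the freshly denoised latents.
        # As t falls, pixels with smaller mask values switch over, so softer
        # edge regions receive fewer effective denoising steps.
        threshold = t.item() / num_train_timesteps
        return torch.where(mask > threshold, latents, mask_latents)

Under this reading, a pixel with mask value 0.7 starts keeping the denoised latents once t drops below 0.7 * num_train_timesteps. That progressive handoff is what turns the soft mask edge into a gradual blend, and it is also why the final whole-image lerp restore is skipped when gradient_mask is set.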