Use non-inverted mask generally(except inpaint model handling)

This commit is contained in:
Sergey Borisov 2024-07-24 00:59:13 +03:00
parent c323a760a5
commit 19c00241c6
3 changed files with 13 additions and 8 deletions

View File

@ -674,7 +674,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else: else:
masked_latents = torch.where(mask < 0.5, 0.0, latents) masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient return mask, masked_latents, self.denoise_mask.gradient
@staticmethod @staticmethod
def prepare_noise_and_latents( def prepare_noise_and_latents(
@ -830,6 +830,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
if mask is not None:
mask = 1 - mask
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate. # below. Investigate whether this is appropriate.

View File

@ -25,7 +25,7 @@ class InpaintExt(ExtensionBase):
"""Initialize InpaintExt. """Initialize InpaintExt.
Args: Args:
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
inpainted. inpainted.
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
@ -65,10 +65,10 @@ class InpaintExt(ExtensionBase):
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self._is_gradient_mask: if self._is_gradient_mask:
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
mask_bool = mask > threshold mask_bool = mask < 1 - threshold
masked_input = torch.where(mask_bool, latents, mask_latents) masked_input = torch.where(mask_bool, latents, mask_latents)
else: else:
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
return masked_input return masked_input
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
@ -111,6 +111,6 @@ class InpaintExt(ExtensionBase):
@callback(ExtensionCallbackType.POST_DENOISE_LOOP) @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext): def restore_unmasked(self, ctx: DenoiseContext):
if self._is_gradient_mask: if self._is_gradient_mask:
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
else: else:
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)

View File

@ -25,7 +25,7 @@ class InpaintModelExt(ExtensionBase):
"""Initialize InpaintModelExt. """Initialize InpaintModelExt.
Args: Args:
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
inpainted. inpainted.
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area. masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width) If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
@ -37,7 +37,10 @@ class InpaintModelExt(ExtensionBase):
if mask is not None and masked_latents is None: if mask is not None and masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!") raise ValueError("Source image required for inpaint mask when inpaint model used!")
self._mask = mask # Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
self._mask = None
if mask is not None:
self._mask = 1 - mask
self._masked_latents = masked_latents self._masked_latents = masked_latents
self._is_gradient_mask = is_gradient_mask self._is_gradient_mask = is_gradient_mask