mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use non-inverted mask generally(except inpaint model handling)
This commit is contained in:
parent
c323a760a5
commit
19c00241c6
@ -674,7 +674,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||||
|
|
||||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
return mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_noise_and_latents(
|
def prepare_noise_and_latents(
|
||||||
@ -830,6 +830,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
if mask is not None:
|
||||||
|
mask = 1 - mask
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
|
@ -25,7 +25,7 @@ class InpaintExt(ExtensionBase):
|
|||||||
"""Initialize InpaintExt.
|
"""Initialize InpaintExt.
|
||||||
Args:
|
Args:
|
||||||
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
inpainted.
|
inpainted.
|
||||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
@ -65,10 +65,10 @@ class InpaintExt(ExtensionBase):
|
|||||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
if self._is_gradient_mask:
|
if self._is_gradient_mask:
|
||||||
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||||
mask_bool = mask > threshold
|
mask_bool = mask < 1 - threshold
|
||||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||||
else:
|
else:
|
||||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
|
||||||
return masked_input
|
return masked_input
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
@ -111,6 +111,6 @@ class InpaintExt(ExtensionBase):
|
|||||||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||||
def restore_unmasked(self, ctx: DenoiseContext):
|
def restore_unmasked(self, ctx: DenoiseContext):
|
||||||
if self._is_gradient_mask:
|
if self._is_gradient_mask:
|
||||||
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
|
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
|
||||||
else:
|
else:
|
||||||
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
|
||||||
|
@ -25,7 +25,7 @@ class InpaintModelExt(ExtensionBase):
|
|||||||
"""Initialize InpaintModelExt.
|
"""Initialize InpaintModelExt.
|
||||||
Args:
|
Args:
|
||||||
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
inpainted.
|
inpainted.
|
||||||
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||||
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||||
@ -37,7 +37,10 @@ class InpaintModelExt(ExtensionBase):
|
|||||||
if mask is not None and masked_latents is None:
|
if mask is not None and masked_latents is None:
|
||||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||||
|
|
||||||
self._mask = mask
|
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
|
||||||
|
self._mask = None
|
||||||
|
if mask is not None:
|
||||||
|
self._mask = 1 - mask
|
||||||
self._masked_latents = masked_latents
|
self._masked_latents = masked_latents
|
||||||
self._is_gradient_mask = is_gradient_mask
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user