mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
@ -14,18 +14,40 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class InpaintExt(ExtensionBase):
|
||||
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
||||
models.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
mask: torch.Tensor,
|
||||
is_gradient_mask: bool,
|
||||
):
|
||||
"""Initialize InpaintExt.
|
||||
Args:
|
||||
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
||||
inpainted.
|
||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||
1.
|
||||
"""
|
||||
super().__init__()
|
||||
self._mask = mask
|
||||
self._is_gradient_mask = is_gradient_mask
|
||||
|
||||
# Noise, which used to noisify unmasked part of image
|
||||
# if noise provided to context, then it will be used
|
||||
# if no noise provided, then noise will be generated based on seed
|
||||
self._noise: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def _is_normal_model(unet: UNet2DConditionModel):
|
||||
""" Checks if the provided UNet belongs to a regular model.
|
||||
The `in_channels` of a UNet vary depending on model type:
|
||||
- normal - 4
|
||||
- depth - 5
|
||||
- inpaint - 9
|
||||
"""
|
||||
return unet.conv_in.in_channels == 4
|
||||
|
||||
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
@ -42,8 +64,8 @@ class InpaintExt(ExtensionBase):
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if self._is_gradient_mask:
|
||||
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
|
||||
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||
mask_bool = mask > threshold
|
||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||
else:
|
||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||
@ -52,11 +74,13 @@ class InpaintExt(ExtensionBase):
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def init_tensors(self, ctx: DenoiseContext):
|
||||
if not self._is_normal_model(ctx.unet):
|
||||
raise Exception("InpaintExt should be used only on normal models!")
|
||||
raise ValueError("InpaintExt should be used only on normal models!")
|
||||
|
||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||
|
||||
self._noise = ctx.inputs.noise
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if self._noise is None:
|
||||
self._noise = torch.randn(
|
||||
ctx.latents.shape,
|
||||
|
@ -13,12 +13,26 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class InpaintModelExt(ExtensionBase):
|
||||
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
||||
models.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
mask: Optional[torch.Tensor],
|
||||
masked_latents: Optional[torch.Tensor],
|
||||
is_gradient_mask: bool,
|
||||
):
|
||||
"""Initialize InpaintModelExt.
|
||||
Args:
|
||||
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
||||
inpainted.
|
||||
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||
1.
|
||||
"""
|
||||
super().__init__()
|
||||
if mask is not None and masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
@ -29,12 +43,18 @@ class InpaintModelExt(ExtensionBase):
|
||||
|
||||
@staticmethod
|
||||
def _is_inpaint_model(unet: UNet2DConditionModel):
|
||||
""" Checks if the provided UNet belongs to a regular model.
|
||||
The `in_channels` of a UNet vary depending on model type:
|
||||
- normal - 4
|
||||
- depth - 5
|
||||
- inpaint - 9
|
||||
"""
|
||||
return unet.conv_in.in_channels == 9
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def init_tensors(self, ctx: DenoiseContext):
|
||||
if not self._is_inpaint_model(ctx.unet):
|
||||
raise Exception("InpaintModelExt should be used only on inpaint models!")
|
||||
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
||||
|
||||
if self._mask is None:
|
||||
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
||||
|
Reference in New Issue
Block a user