Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov 2024-07-23 23:34:28 +03:00
parent 9d1fcba415
commit c323a760a5
3 changed files with 68 additions and 23 deletions

View File

@ -732,10 +732,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype = TorchDevice.choose_torch_dtype() dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape _, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data( conditioning_data = self.get_conditioning_data(
@ -768,21 +764,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
) )
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback() # get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key) unet_config = context.models.get_config(self.unet.unet.key)
@ -799,6 +780,26 @@ class DenoiseLatentsInvocation(BaseInvocation):
elif mask is not None: elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask)) ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
# Initialize context for modular denoise
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# ext: t2i/ip adapter # ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

View File

@ -14,18 +14,40 @@ if TYPE_CHECKING:
class InpaintExt(ExtensionBase): class InpaintExt(ExtensionBase):
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
models.
"""
def __init__( def __init__(
self, self,
mask: torch.Tensor, mask: torch.Tensor,
is_gradient_mask: bool, is_gradient_mask: bool,
): ):
"""Initialize InpaintExt.
Args:
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
inpainted.
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
1.
"""
super().__init__() super().__init__()
self._mask = mask self._mask = mask
self._is_gradient_mask = is_gradient_mask self._is_gradient_mask = is_gradient_mask
# Noise, which used to noisify unmasked part of image
# if noise provided to context, then it will be used
# if no noise provided, then noise will be generated based on seed
self._noise: Optional[torch.Tensor] = None self._noise: Optional[torch.Tensor] = None
@staticmethod @staticmethod
def _is_normal_model(unet: UNet2DConditionModel): def _is_normal_model(unet: UNet2DConditionModel):
""" Checks if the provided UNet belongs to a regular model.
The `in_channels` of a UNet vary depending on model type:
- normal - 4
- depth - 5
- inpaint - 9
"""
return unet.conv_in.in_channels == 4 return unet.conv_in.in_channels == 4
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
@ -42,8 +64,8 @@ class InpaintExt(ExtensionBase):
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self._is_gradient_mask: if self._is_gradient_mask:
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did mask_bool = mask > threshold
masked_input = torch.where(mask_bool, latents, mask_latents) masked_input = torch.where(mask_bool, latents, mask_latents)
else: else:
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
@ -52,11 +74,13 @@ class InpaintExt(ExtensionBase):
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def init_tensors(self, ctx: DenoiseContext): def init_tensors(self, ctx: DenoiseContext):
if not self._is_normal_model(ctx.unet): if not self._is_normal_model(ctx.unet):
raise Exception("InpaintExt should be used only on normal models!") raise ValueError("InpaintExt should be used only on normal models!")
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
self._noise = ctx.inputs.noise self._noise = ctx.inputs.noise
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if self._noise is None: if self._noise is None:
self._noise = torch.randn( self._noise = torch.randn(
ctx.latents.shape, ctx.latents.shape,

View File

@ -13,12 +13,26 @@ if TYPE_CHECKING:
class InpaintModelExt(ExtensionBase): class InpaintModelExt(ExtensionBase):
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
models.
"""
def __init__( def __init__(
self, self,
mask: Optional[torch.Tensor], mask: Optional[torch.Tensor],
masked_latents: Optional[torch.Tensor], masked_latents: Optional[torch.Tensor],
is_gradient_mask: bool, is_gradient_mask: bool,
): ):
"""Initialize InpaintModelExt.
Args:
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
inpainted.
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
1.
"""
super().__init__() super().__init__()
if mask is not None and masked_latents is None: if mask is not None and masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!") raise ValueError("Source image required for inpaint mask when inpaint model used!")
@ -29,12 +43,18 @@ class InpaintModelExt(ExtensionBase):
@staticmethod @staticmethod
def _is_inpaint_model(unet: UNet2DConditionModel): def _is_inpaint_model(unet: UNet2DConditionModel):
""" Checks if the provided UNet belongs to a regular model.
The `in_channels` of a UNet vary depending on model type:
- normal - 4
- depth - 5
- inpaint - 9
"""
return unet.conv_in.in_channels == 9 return unet.conv_in.in_channels == 9
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def init_tensors(self, ctx: DenoiseContext): def init_tensors(self, ctx: DenoiseContext):
if not self._is_inpaint_model(ctx.unet): if not self._is_inpaint_model(ctx.unet):
raise Exception("InpaintModelExt should be used only on inpaint models!") raise ValueError("InpaintModelExt should be used only on inpaint models!")
if self._mask is None: if self._mask is None:
self._mask = torch.ones_like(ctx.latents[:1, :1]) self._mask = torch.ones_like(ctx.latents[:1, :1])