diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 3a9e0291af..0d9293be02 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -732,10 +732,6 @@ class DenoiseLatentsInvocation(BaseInvocation): dtype = TorchDevice.choose_torch_dtype() seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) - latents = latents.to(device=device, dtype=dtype) - if noise is not None: - noise = noise.to(device=device, dtype=dtype) - _, _, latent_height, latent_width = latents.shape conditioning_data = self.get_conditioning_data( @@ -768,21 +764,6 @@ class DenoiseLatentsInvocation(BaseInvocation): denoising_end=self.denoising_end, ) - denoise_ctx = DenoiseContext( - inputs=DenoiseInputs( - orig_latents=latents, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise, - seed=seed, - scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, - attention_processor_cls=CustomAttnProcessor2_0, - ), - unet=None, - scheduler=scheduler, - ) - # get the unet's config so that we can pass the base to sd_step_callback() unet_config = context.models.get_config(self.unet.unet.key) @@ -799,6 +780,26 @@ class DenoiseLatentsInvocation(BaseInvocation): elif mask is not None: ext_manager.add_extension(InpaintExt(mask, is_gradient_mask)) + # Initialize context for modular denoise + latents = latents.to(device=device, dtype=dtype) + if noise is not None: + noise = noise.to(device=device, dtype=dtype) + + denoise_ctx = DenoiseContext( + inputs=DenoiseInputs( + orig_latents=latents, + timesteps=timesteps, + init_timestep=init_timestep, + noise=noise, + seed=seed, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + attention_processor_cls=CustomAttnProcessor2_0, + ), + unet=None, + scheduler=scheduler, + ) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py index 27ea0a4ed6..fa58958b47 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -14,18 +14,40 @@ if TYPE_CHECKING: class InpaintExt(ExtensionBase): + """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting + models. + """ def __init__( self, mask: torch.Tensor, is_gradient_mask: bool, ): + """Initialize InpaintExt. + Args: + mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are + expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + inpainted. + is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range + from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or + 1. + """ super().__init__() self._mask = mask self._is_gradient_mask = is_gradient_mask + + # Noise, which used to noisify unmasked part of image + # if noise provided to context, then it will be used + # if no noise provided, then noise will be generated based on seed self._noise: Optional[torch.Tensor] = None @staticmethod def _is_normal_model(unet: UNet2DConditionModel): + """ Checks if the provided UNet belongs to a regular model. + The `in_channels` of a UNet vary depending on model type: + - normal - 4 + - depth - 5 + - inpaint - 9 + """ return unet.conv_in.in_channels == 4 def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: @@ -42,8 +64,8 @@ class InpaintExt(ExtensionBase): # mask_latents = self.scheduler.scale_model_input(mask_latents, t) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) if self._is_gradient_mask: - threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps - mask_bool = mask > threshhold # I don't know when mask got inverted, but it did + threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps + mask_bool = mask > threshold masked_input = torch.where(mask_bool, latents, mask_latents) else: masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) @@ -52,11 +74,13 @@ class InpaintExt(ExtensionBase): @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def init_tensors(self, ctx: DenoiseContext): if not self._is_normal_model(ctx.unet): - raise Exception("InpaintExt should be used only on normal models!") + raise ValueError("InpaintExt should be used only on normal models!") self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) self._noise = ctx.inputs.noise + # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner). + # We still need noise for inpainting, so we generate it from the seed here. if self._noise is None: self._noise = torch.randn( ctx.latents.shape, diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py index 9be259408f..b5a08a85a8 100644 --- a/invokeai/backend/stable_diffusion/extensions/inpaint_model.py +++ b/invokeai/backend/stable_diffusion/extensions/inpaint_model.py @@ -13,12 +13,26 @@ if TYPE_CHECKING: class InpaintModelExt(ExtensionBase): + """An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting + models. + """ def __init__( self, mask: Optional[torch.Tensor], masked_latents: Optional[torch.Tensor], is_gradient_mask: bool, ): + """Initialize InpaintModelExt. + Args: + mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are + expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be + inpainted. + masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area. + If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width) + is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range + from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or + 1. + """ super().__init__() if mask is not None and masked_latents is None: raise ValueError("Source image required for inpaint mask when inpaint model used!") @@ -29,12 +43,18 @@ class InpaintModelExt(ExtensionBase): @staticmethod def _is_inpaint_model(unet: UNet2DConditionModel): + """ Checks if the provided UNet belongs to a regular model. + The `in_channels` of a UNet vary depending on model type: + - normal - 4 + - depth - 5 + - inpaint - 9 + """ return unet.conv_in.in_channels == 9 @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) def init_tensors(self, ctx: DenoiseContext): if not self._is_inpaint_model(ctx.unet): - raise Exception("InpaintModelExt should be used only on inpaint models!") + raise ValueError("InpaintModelExt should be used only on inpaint models!") if self._mask is None: self._mask = torch.ones_like(ctx.latents[:1, :1])