mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
9d1fcba415
commit
c323a760a5
@ -732,10 +732,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
|
||||||
if noise is not None:
|
|
||||||
noise = noise.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
@ -768,21 +764,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
denoise_ctx = DenoiseContext(
|
|
||||||
inputs=DenoiseInputs(
|
|
||||||
orig_latents=latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
init_timestep=init_timestep,
|
|
||||||
noise=noise,
|
|
||||||
seed=seed,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
attention_processor_cls=CustomAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
unet=None,
|
|
||||||
scheduler=scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
@ -799,6 +780,26 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
elif mask is not None:
|
elif mask is not None:
|
||||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||||
|
|
||||||
|
# Initialize context for modular denoise
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
denoise_ctx = DenoiseContext(
|
||||||
|
inputs=DenoiseInputs(
|
||||||
|
orig_latents=latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
noise=noise,
|
||||||
|
seed=seed,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
attention_processor_cls=CustomAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
unet=None,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
|
|
||||||
|
@ -14,18 +14,40 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class InpaintExt(ExtensionBase):
|
class InpaintExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mask: torch.Tensor,
|
mask: torch.Tensor,
|
||||||
is_gradient_mask: bool,
|
is_gradient_mask: bool,
|
||||||
):
|
):
|
||||||
|
"""Initialize InpaintExt.
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._mask = mask
|
self._mask = mask
|
||||||
self._is_gradient_mask = is_gradient_mask
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
# Noise, which used to noisify unmasked part of image
|
||||||
|
# if noise provided to context, then it will be used
|
||||||
|
# if no noise provided, then noise will be generated based on seed
|
||||||
self._noise: Optional[torch.Tensor] = None
|
self._noise: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_normal_model(unet: UNet2DConditionModel):
|
def _is_normal_model(unet: UNet2DConditionModel):
|
||||||
|
""" Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
return unet.conv_in.in_channels == 4
|
return unet.conv_in.in_channels == 4
|
||||||
|
|
||||||
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||||
@ -42,8 +64,8 @@ class InpaintExt(ExtensionBase):
|
|||||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
if self._is_gradient_mask:
|
if self._is_gradient_mask:
|
||||||
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||||
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
|
mask_bool = mask > threshold
|
||||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||||
else:
|
else:
|
||||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||||
@ -52,11 +74,13 @@ class InpaintExt(ExtensionBase):
|
|||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
def init_tensors(self, ctx: DenoiseContext):
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
if not self._is_normal_model(ctx.unet):
|
if not self._is_normal_model(ctx.unet):
|
||||||
raise Exception("InpaintExt should be used only on normal models!")
|
raise ValueError("InpaintExt should be used only on normal models!")
|
||||||
|
|
||||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
self._noise = ctx.inputs.noise
|
self._noise = ctx.inputs.noise
|
||||||
|
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||||
|
# We still need noise for inpainting, so we generate it from the seed here.
|
||||||
if self._noise is None:
|
if self._noise is None:
|
||||||
self._noise = torch.randn(
|
self._noise = torch.randn(
|
||||||
ctx.latents.shape,
|
ctx.latents.shape,
|
||||||
|
@ -13,12 +13,26 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class InpaintModelExt(ExtensionBase):
|
class InpaintModelExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mask: Optional[torch.Tensor],
|
mask: Optional[torch.Tensor],
|
||||||
masked_latents: Optional[torch.Tensor],
|
masked_latents: Optional[torch.Tensor],
|
||||||
is_gradient_mask: bool,
|
is_gradient_mask: bool,
|
||||||
):
|
):
|
||||||
|
"""Initialize InpaintModelExt.
|
||||||
|
Args:
|
||||||
|
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||||
|
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if mask is not None and masked_latents is None:
|
if mask is not None and masked_latents is None:
|
||||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||||
@ -29,12 +43,18 @@ class InpaintModelExt(ExtensionBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_inpaint_model(unet: UNet2DConditionModel):
|
def _is_inpaint_model(unet: UNet2DConditionModel):
|
||||||
|
""" Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
return unet.conv_in.in_channels == 9
|
return unet.conv_in.in_channels == 9
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
def init_tensors(self, ctx: DenoiseContext):
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
if not self._is_inpaint_model(ctx.unet):
|
if not self._is_inpaint_model(ctx.unet):
|
||||||
raise Exception("InpaintModelExt should be used only on inpaint models!")
|
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
||||||
|
|
||||||
if self._mask is None:
|
if self._mask is None:
|
||||||
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
||||||
|
Loading…
Reference in New Issue
Block a user