Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Same changes as in other PRs; add a check for running inpainting on an inpaint model without a source image
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
parent 58f3072b91
commit 5003e5d763
@@ -718,7 +718,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return seed, noise, latents
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        if os.environ.get("USE_MODULAR_DENOISE", False):
+        if os.environ.get("USE_MODULAR_DENOISE", True):
            return self._new_invoke(context)
         else:
             return self._old_invoke(context)
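One subtlety of the flipped default: os.environ.get returns the raw string whenever the variable is set, and any non-empty string (even "0" or "false") is truthy in Python, so with the default now True the only way to select the legacy path is to set the variable to an empty string. A minimal sketch (variable name from the diff, usage hypothetical):

import os

# With the default flipped to True, the modular path runs unless the variable
# is set to an empty string: os.environ.get returns the raw string when the
# variable is set, and any non-empty string (even "0") is truthy in Python.
os.environ["USE_MODULAR_DENOISE"] = ""  # falsy -> legacy _old_invoke path
assert not os.environ.get("USE_MODULAR_DENOISE", True)

del os.environ["USE_MODULAR_DENOISE"]  # unset -> default True -> _new_invoke path
assert os.environ.get("USE_MODULAR_DENOISE", True)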
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import einops
 import torch
@@ -20,8 +20,9 @@ class InpaintExt(ExtensionBase):
         is_gradient_mask: bool,
     ):
         super().__init__()
-        self.mask = mask
-        self.is_gradient_mask = is_gradient_mask
+        self._mask = mask
+        self._is_gradient_mask = is_gradient_mask
+        self._noise: Optional[torch.Tensor] = None
 
     @staticmethod
     def _is_normal_model(unet: UNet2DConditionModel):
@@ -29,18 +30,18 @@ class InpaintExt(ExtensionBase):
 
     def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
         if t.dim() == 0:
             # some schedulers expect t to be one-dimensional.
             # TODO: file diffusers bug about inconsistency?
             t = einops.repeat(t, "-> batch", batch=batch_size)
         # Noise shouldn't be re-randomized between steps here. The multistep schedulers
         # get very confused about what is happening from step to step when we do that.
-        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self.noise, t)
+        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
         mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if self.is_gradient_mask:
+        if self._is_gradient_mask:
             threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
             mask_bool = mask > threshhold  # I don't know when mask got inverted, but it did
             masked_input = torch.where(mask_bool, latents, mask_latents)
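Apart from the attribute renames, the logic of _apply_mask is unchanged: the original latents are re-noised to the current timestep with the scheduler, then composited with the working latents so unmasked regions follow the expected noisy trajectory of the source image. A minimal sketch of the idea, assuming a diffusers-style scheduler and hypothetical names:

import torch

def apply_mask_sketch(scheduler, latents, orig_latents, noise, mask, t):
    # Re-noise the original image latents to timestep t with the SAME fixed
    # noise every step; re-randomizing it confuses multistep schedulers.
    noised_orig = scheduler.add_noise(orig_latents, noise, t)
    # mask == 1 marks the inpaint region: keep the model's latents there,
    # let everything else track the noisy trajectory of the source image.
    return torch.where(mask > 0.5, latents, noised_orig)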
@@ -53,11 +54,11 @@ class InpaintExt(ExtensionBase):
         if not self._is_normal_model(ctx.unet):
             raise Exception("InpaintExt should be used only on normal models!")
 
-        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-        self.noise = ctx.inputs.noise
-        if self.noise is None:
-            self.noise = torch.randn(
+        self._noise = ctx.inputs.noise
+        if self._noise is None:
+            self._noise = torch.randn(
                 ctx.latents.shape,
                 dtype=torch.float32,
                 device="cpu",
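The hunk is truncated just after device="cpu", so the rest of the torch.randn call is not shown here; presumably a seeded generator follows. As a hedged sketch of the fallback-noise pattern (function name hypothetical):

import torch

def make_fallback_noise(shape, seed: int) -> torch.Tensor:
    # Draw noise on the CPU with a seeded generator so the fallback is
    # reproducible regardless of the execution device.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, dtype=torch.float32, device="cpu", generator=generator)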
@@ -85,7 +86,7 @@ class InpaintExt(ExtensionBase):
     # restore unmasked part after the last step is completed
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
     def restore_unmasked(self, ctx: DenoiseContext):
-        if self.is_gradient_mask:
-            ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
         else:
-            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
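The two restore branches differ in how partial mask values are treated: a gradient mask is thresholded at zero, while a plain mask is used as a linear blend weight. A small self-contained illustration (values hypothetical):

import torch

orig = torch.zeros(1, 4, 2, 2)      # latents of the source image
denoised = torch.ones(1, 4, 2, 2)   # latents after the denoise loop
mask = torch.tensor([[[[0.0, 0.25], [0.75, 1.0]]]])  # 1 = inpaint region

# Gradient mask: any pixel touched by the mask keeps the denoised result.
hard = torch.where(mask > 0, denoised, orig)
# Plain mask: linear interpolation between original and denoised latents.
soft = torch.lerp(orig, denoised, mask)
print(hard[0, 0])  # 0 only where mask == 0
print(soft[0, 0])  # 0.0, 0.25, 0.75, 1.0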
@@ -20,9 +20,12 @@ class InpaintModelExt(ExtensionBase):
         is_gradient_mask: bool,
     ):
         super().__init__()
-        self.mask = mask
-        self.masked_latents = masked_latents
-        self.is_gradient_mask = is_gradient_mask
+        if mask is not None and masked_latents is None:
+            raise ValueError("Source image required for inpaint mask when inpaint model used!")
+
+        self._mask = mask
+        self._masked_latents = masked_latents
+        self._is_gradient_mask = is_gradient_mask
 
     @staticmethod
     def _is_inpaint_model(unet: UNet2DConditionModel):
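This is the check named in the commit title: requesting masked inpainting on an inpaint model without supplying the encoded source image now fails at construction time instead of silently inpainting against zero latents. A hypothetical usage sketch (the import path is assumed from the class name and does not appear in this diff):

import torch

# Module path assumed, not shown in the diff.
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt

try:
    InpaintModelExt(mask=torch.ones(1, 1, 64, 64), masked_latents=None, is_gradient_mask=False)
except ValueError as err:
    print(err)  # Source image required for inpaint mask when inpaint model used!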
@@ -33,21 +36,21 @@ class InpaintModelExt(ExtensionBase):
         if not self._is_inpaint_model(ctx.unet):
             raise Exception("InpaintModelExt should be used only on inpaint models!")
 
-        if self.mask is None:
-            self.mask = torch.ones_like(ctx.latents[:1, :1])
-        self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        if self._mask is None:
+            self._mask = torch.ones_like(ctx.latents[:1, :1])
+        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-        if self.masked_latents is None:
-            self.masked_latents = torch.zeros_like(ctx.latents[:1])
-        self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+        if self._masked_latents is None:
+            self._masked_latents = torch.zeros_like(ctx.latents[:1])
+        self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
     # TODO: any ideas about order value?
     # do last so that other extensions works with normal latents
     @callback(ExtensionCallbackType.PRE_UNET, order=1000)
     def append_inpaint_layers(self, ctx: DenoiseContext):
         batch_size = ctx.unet_kwargs.sample.shape[0]
-        b_mask = torch.cat([self.mask] * batch_size)
-        b_masked_latents = torch.cat([self.masked_latents] * batch_size)
+        b_mask = torch.cat([self._mask] * batch_size)
+        b_masked_latents = torch.cat([self._masked_latents] * batch_size)
         ctx.unet_kwargs.sample = torch.cat(
             [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
             dim=1,
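The concatenation along dim=1 is what feeds an inpaint UNet its extra conditioning channels. A sketch of the resulting layout, assuming the standard Stable Diffusion inpaint arrangement of 4 latent + 1 mask + 4 masked-image channels:

import torch

batch_size = 2
sample = torch.randn(batch_size, 4, 64, 64)             # noisy latents
b_mask = torch.ones(batch_size, 1, 64, 64)              # inpaint mask, batched
b_masked_latents = torch.zeros(batch_size, 4, 64, 64)   # encoded source, holes zeroed
unet_input = torch.cat([sample, b_mask, b_masked_latents], dim=1)
assert unet_input.shape == (batch_size, 9, 64, 64)  # 4 + 1 + 4 channels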
@@ -57,10 +60,7 @@ class InpaintModelExt(ExtensionBase):
     # restore unmasked part as inpaint model can change unmasked part slightly
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
     def restore_unmasked(self, ctx: DenoiseContext):
-        if self.mask is None:
-            return
-
-        if self.is_gradient_mask:
-            ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.inputs.orig_latents)
+        if self._is_gradient_mask:
+            ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
         else:
-            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self.mask)
+            ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
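The old "if self.mask is None: return" early-out can go because setup now substitutes an all-ones mask when none is supplied, and with an all-ones mask the blend is a no-op. A quick check of that equivalence (shapes hypothetical):

import torch

orig = torch.randn(1, 4, 8, 8)
denoised = torch.randn(1, 4, 8, 8)
ones_mask = torch.ones_like(orig[:1, :1])

# lerp(orig, denoised, 1) == denoised: restoring with an all-ones mask
# changes nothing, so the None early-out became redundant.
assert torch.allclose(torch.lerp(orig, denoised, ones_mask), denoised)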