mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify handling of inpainting models. Improve the in-code documentation around inpainting.
This commit is contained in:
parent
875673c9ba
commit
22704dd542
@ -38,40 +38,6 @@ class PipelineIntermediateState:
|
|||||||
predicted_original: Optional[torch.Tensor] = None
|
predicted_original: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AddsMaskLatents:
|
|
||||||
"""Add the channels required for inpainting model input.
|
|
||||||
|
|
||||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
|
||||||
and the latent encoding of the base image.
|
|
||||||
|
|
||||||
This class assumes the same mask and base image should apply to all items in the batch.
|
|
||||||
"""
|
|
||||||
|
|
||||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
|
||||||
mask: torch.Tensor
|
|
||||||
initial_image_latents: torch.Tensor
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
latents: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
text_embeddings: torch.Tensor,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
model_input = self.add_mask_channels(latents)
|
|
||||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
|
||||||
|
|
||||||
def add_mask_channels(self, latents):
|
|
||||||
batch_size = latents.size(0)
|
|
||||||
# duplicate mask and latents for each batch
|
|
||||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
|
||||||
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
|
||||||
# add mask and image as additional channels
|
|
||||||
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
|
|
||||||
return model_input
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AddsMaskGuidance:
|
class AddsMaskGuidance:
|
||||||
mask: torch.Tensor
|
mask: torch.Tensor
|
||||||
@ -273,6 +239,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||||
raise Exception("Should not be called")
|
raise Exception("Should not be called")
|
||||||
|
|
||||||
|
def add_inpainting_channels_to_latents(
    self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
) -> torch.Tensor:
    """Given a `latents` tensor, adds the mask and image latents channels required for inpainting.

    Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
    input of shape (N, 9, H, W). The 9 channels are defined as follows:
    - Channel 0-3: The latents being denoised.
    - Channel 4: The mask indicating which parts of the image are being inpainted.
    - Channel 5-8: The latent representation of the masked reference image being inpainted.

    This function assumes that the same mask and base image should apply to all items in the batch.

    Args:
        latents: The latents being denoised. Shape: (N, 4, H, W).
        masked_ref_image_latents: Latents of the masked reference image. Shape: (1, 4, H, W).
        inpainting_mask: The inpainting mask. Shape: (1, 1, H, W).

    Returns:
        A tensor of shape (N, 9, H, W): `latents`, then the mask, then the reference image latents, concatenated
        along the channel dimension.

    Raises:
        AssertionError: If any of the inputs does not have the shape documented above.
    """
    # Validate assumptions about input tensor shapes.
    # NOTE: `Tensor.shape` is a `torch.Size` (a tuple subclass). Comparing it against a *list* is always False, so
    # the shapes are normalized to plain tuples and compared against tuples.
    batch_size, latent_channels, latent_height, latent_width = latents.shape
    assert latent_channels == 4, f"Expected latents with 4 channels, got {latent_channels}."

    expected_ref_shape = (1, 4, latent_height, latent_width)
    assert (
        tuple(masked_ref_image_latents.shape) == expected_ref_shape
    ), f"Expected masked_ref_image_latents shape {expected_ref_shape}, got {tuple(masked_ref_image_latents.shape)}."

    expected_mask_shape = (1, 1, latent_height, latent_width)
    assert (
        tuple(inpainting_mask.shape) == expected_mask_shape
    ), f"Expected inpainting_mask shape {expected_mask_shape}, got {tuple(inpainting_mask.shape)}."

    # Broadcast masked_ref_image_latents and inpainting_mask along the batch dimension to match the latents batch
    # size (-1 leaves the remaining dimensions unchanged).
    original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
    inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)

    # Concatenate along the channel dimension.
    return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
|
||||||
|
|
||||||
def latents_from_embeddings(
|
def latents_from_embeddings(
|
||||||
self,
|
self,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
@ -302,94 +294,94 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
|
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||||
|
# full noise. Investigate the history of why this got commented out.
|
||||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||||
|
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
|
|
||||||
|
# Handle mask guidance (a.k.a. inpainting).
|
||||||
mask_guidance: AddsMaskGuidance | None = None
|
mask_guidance: AddsMaskGuidance | None = None
|
||||||
if mask is not None:
|
if mask is not None and not is_inpainting_model(self.unet):
|
||||||
if is_inpainting_model(self.unet):
|
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
||||||
if masked_latents is None:
|
# apply mask guidance to the latents.
|
||||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
|
||||||
|
|
||||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||||
self._unet_forward, mask, masked_latents
|
# We still need noise for inpainting, so we generate it from the seed here.
|
||||||
)
|
if noise is None:
|
||||||
else:
|
noise = torch.randn(
|
||||||
# if no noise provided, noisify unmasked area based on seed
|
orig_latents.shape,
|
||||||
if noise is None:
|
dtype=torch.float32,
|
||||||
noise = torch.randn(
|
device="cpu",
|
||||||
orig_latents.shape,
|
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||||
dtype=torch.float32,
|
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||||
device="cpu",
|
|
||||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
|
||||||
|
|
||||||
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
|
mask_guidance = AddsMaskGuidance(
|
||||||
|
mask=mask,
|
||||||
try:
|
mask_latents=orig_latents,
|
||||||
use_ip_adapter = ip_adapter_data is not None
|
scheduler=self.scheduler,
|
||||||
use_regional_prompting = (
|
noise=noise,
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
is_gradient_mask=is_gradient_mask,
|
||||||
)
|
)
|
||||||
unet_attention_patcher = None
|
|
||||||
attn_ctx = nullcontext()
|
|
||||||
|
|
||||||
if use_ip_adapter or use_regional_prompting:
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
use_regional_prompting = (
|
||||||
[
|
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||||
{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks}
|
)
|
||||||
for ipa in ip_adapter_data
|
unet_attention_patcher = None
|
||||||
]
|
attn_ctx = nullcontext()
|
||||||
if use_ip_adapter
|
|
||||||
else None
|
if use_ip_adapter or use_regional_prompting:
|
||||||
|
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||||
|
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||||
|
if use_ip_adapter
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||||
|
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
|
|
||||||
|
with attn_ctx:
|
||||||
|
callback(
|
||||||
|
PipelineIntermediateState(
|
||||||
|
step=-1,
|
||||||
|
order=self.scheduler.order,
|
||||||
|
total_steps=len(timesteps),
|
||||||
|
timestep=self.scheduler.config.num_train_timesteps,
|
||||||
|
latents=latents,
|
||||||
)
|
)
|
||||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
)
|
||||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
batched_t = t.expand(batch_size)
|
||||||
|
step_output = self.step(
|
||||||
|
t=batched_t,
|
||||||
|
latents=latents,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
step_index=i,
|
||||||
|
total_step_count=len(timesteps),
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
mask_guidance=mask_guidance,
|
||||||
|
mask=mask,
|
||||||
|
masked_latents=masked_latents,
|
||||||
|
control_data=control_data,
|
||||||
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
|
)
|
||||||
|
latents = step_output.prev_sample
|
||||||
|
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||||
|
|
||||||
with attn_ctx:
|
|
||||||
callback(
|
callback(
|
||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
step=-1,
|
step=i,
|
||||||
order=self.scheduler.order,
|
order=self.scheduler.order,
|
||||||
total_steps=len(timesteps),
|
total_steps=len(timesteps),
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
timestep=int(t),
|
||||||
latents=latents,
|
latents=latents,
|
||||||
|
predicted_original=predicted_original,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
|
||||||
batched_t = t.expand(batch_size)
|
|
||||||
step_output = self.step(
|
|
||||||
batched_t,
|
|
||||||
latents,
|
|
||||||
conditioning_data,
|
|
||||||
step_index=i,
|
|
||||||
total_step_count=len(timesteps),
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
mask_guidance=mask_guidance,
|
|
||||||
control_data=control_data,
|
|
||||||
ip_adapter_data=ip_adapter_data,
|
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
|
||||||
)
|
|
||||||
latents = step_output.prev_sample
|
|
||||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
|
||||||
|
|
||||||
callback(
|
|
||||||
PipelineIntermediateState(
|
|
||||||
step=i,
|
|
||||||
order=self.scheduler.order,
|
|
||||||
total_steps=len(timesteps),
|
|
||||||
timestep=int(t),
|
|
||||||
latents=latents,
|
|
||||||
predicted_original=predicted_original,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
|
||||||
|
|
||||||
# restore unmasked part after the last step is completed
|
# restore unmasked part after the last step is completed
|
||||||
# in-process masking happens before each step
|
# in-process masking happens before each step
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@ -411,7 +403,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
mask_guidance: AddsMaskGuidance | None = None,
|
mask_guidance: AddsMaskGuidance | None,
|
||||||
|
mask: torch.Tensor | None,
|
||||||
|
masked_latents: torch.Tensor | None,
|
||||||
control_data: list[ControlNetData] | None = None,
|
control_data: list[ControlNetData] | None = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
@ -419,7 +413,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
|
|
||||||
|
# Handle masked image-to-image (a.k.a inpainting).
|
||||||
if mask_guidance is not None:
|
if mask_guidance is not None:
|
||||||
|
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
|
||||||
latents = mask_guidance(latents, timestep)
|
latents = mask_guidance(latents, timestep)
|
||||||
|
|
||||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||||
@ -468,6 +464,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
down_intrablock_additional_residuals = accum_adapter_state
|
down_intrablock_additional_residuals = accum_adapter_state
|
||||||
|
|
||||||
|
# Handle inpainting models.
|
||||||
|
if is_inpainting_model(self.unet):
|
||||||
|
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
|
||||||
|
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
|
||||||
|
# latents.
|
||||||
|
if mask is not None:
|
||||||
|
if masked_latents is None:
|
||||||
|
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||||
|
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||||
|
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
|
||||||
|
# We generate a global mask and empty original image so that we can still generate in this
|
||||||
|
# configuration.
|
||||||
|
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
|
||||||
|
# to do this.
|
||||||
|
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
|
||||||
|
# mask and original image once rather than on every denoising step.
|
||||||
|
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||||
|
latents=latent_model_input,
|
||||||
|
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
|
||||||
|
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
|
||||||
|
)
|
||||||
|
|
||||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||||
|
Loading…
Reference in New Issue
Block a user