Code cleanup and documentation around FLUX inpainting.

Ryan Dick 2024-08-30 14:46:04 +00:00
parent 262b67b9cb
commit 75d0558241
5 changed files with 124 additions and 108 deletions

--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torchvision.transforms as tv_transforms
@@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.denoise import denoise
+from invokeai.backend.flux.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.sampling_utils import (
     generate_img_ids,
@@ -31,8 +32,6 @@ from invokeai.backend.flux.sampling_utils import (
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice

-EPS = 1e-6
-

 @invocation(
     "flux_text_to_image",
@@ -51,6 +50,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.latents,
         input=Input.Connection,
     )
+    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
     denoise_mask: Optional[DenoiseMaskField] = InputField(
         default=None,
         description=FieldDescriptions.denoise_mask,
@@ -122,6 +122,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path

+        # Calculate the timestep schedule.
         image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
         timesteps = get_schedule(
             num_steps=self.num_steps,
@ -130,7 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
) )
# Prepare input latent image. # Prepare input latent image.
if self.denoising_start > EPS: if self.denoising_start > 1e-5:
# If denoising_start > 0, we are doing image-to-image. # If denoising_start > 0, we are doing image-to-image.
if init_latents is None: if init_latents is None:
raise ValueError("latents must be provided if denoising_start > 0.") raise ValueError("latents must be provided if denoising_start > 0.")
@@ -144,16 +145,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             t_0 = timesteps[0]
             x = t_0 * noise + (1.0 - t_0) * init_latents
         else:
-            # We are not doing image-to-image, so we are starting from noise.
+            # We are not doing image-to-image, so start from noise.
             x = noise

+        # Prepare inpaint mask.
         inpaint_mask = self._prep_inpaint_mask(context, x)
-        if inpaint_mask is not None:
-            assert init_latents is not None
-            # Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines
-            # up with the init_latents.
-            inpaint_mask = inpaint_mask.expand_as(init_latents)

         b, _c, h, w = x.shape
         img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
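For intuition, the image-to-image branch above places the latents exactly on the rectified-flow trajectory at the first timestep before any denoising happens: a linear blend of pure noise (t=1) and the clean init latents (t=0). A minimal standalone sketch, with hypothetical latent shapes and a hypothetical t_0:

import torch

# Hypothetical unpacked FLUX latents: 16 channels over a 64x64 latent grid.
init_latents = torch.randn(1, 16, 64, 64)
noise = torch.randn_like(init_latents)

# First timestep of the clipped schedule; denoising_start > 0 implies t_0 < 1.
t_0 = 0.6

# Linear interpolation between pure noise (t=1) and clean latents (t=0).
x = t_0 * noise + (1.0 - t_0) * init_latents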
@@ -167,12 +162,74 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         noise = pack(noise)
         x = pack(x)

-        # Verify that we calculated the image_seq_len correctly.
+        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
         assert image_seq_len == x.shape[1]

+        # Prepare inpaint extension.
+        inpaint_extension: InpaintExtension | None = None
+        if inpaint_mask is not None:
+            assert init_latents is not None
+            inpaint_extension = InpaintExtension(
+                init_latents=init_latents,
+                inpaint_mask=inpaint_mask,
+                noise=noise,
+            )
+
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)

+            x = denoise(
+                model=transformer,
+                img=x,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                step_callback=self._build_step_callback(context),
+                guidance=self.guidance,
+                inpaint_extension=inpaint_extension,
+            )
+
+        x = unpack(x.float(), self.height, self.width)
+        return x
+
+    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
+        """Prepare the inpaint mask.
+
+        - Loads the mask
+        - Resizes if necessary
+        - Casts to same device/dtype as latents
+        - Expands mask to the same shape as latents so that they line up after 'packing'
+
+        Args:
+            context (InvocationContext): The invocation context, for loading the inpaint mask.
+            latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
+                device, and dtype for the inpaint mask.
+
+        Returns:
+            torch.Tensor | None: Inpaint mask.
+        """
+        if self.denoise_mask is None:
+            return None
+
+        mask = context.tensors.load(self.denoise_mask.mask_name)
+
+        _, _, latent_height, latent_width = latents.shape
+        mask = tv_resize(
+            img=mask,
+            size=[latent_height, latent_width],
+            interpolation=tv_transforms.InterpolationMode.BILINEAR,
+            antialias=False,
+        )
+
+        mask = mask.to(device=latents.device, dtype=latents.dtype)
+
+        # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
+        # `latents`.
+        return mask.expand_as(latents)
+
+    def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
-            def step_callback() -> None:
-                if context.util.is_canceled():
-                    raise CanceledException
+        def step_callback() -> None:
+            if context.util.is_canceled():
+                raise CanceledException
@@ -202,43 +259,4 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             #     ProgressImage(dataURL=dataURL, width=width, height=height),
             # )

-            x = denoise(
-                model=transformer,
-                img=x,
-                img_ids=img_ids,
-                txt=t5_embeddings,
-                txt_ids=txt_ids,
-                vec=clip_embeddings,
-                timesteps=timesteps,
-                step_callback=step_callback,
-                guidance=self.guidance,
-                init_latents=init_latents,
-                noise=noise,
-                inpaint_mask=inpaint_mask,
-            )
-
-        x = unpack(x.float(), self.height, self.width)
-        return x
-
-    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
-        """Prepare the inpaint mask.
-
-        Loads the mask, resizes if necessary, casts to same device/dtype as latents.
-
-        Returns:
-            tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask)
-        """
-        if self.denoise_mask is None:
-            return None
-
-        mask = context.tensors.load(self.denoise_mask.mask_name)
-
-        _, _, latent_height, latent_width = latents.shape
-        mask = tv_resize(
-            img=mask,
-            size=[latent_height, latent_width],
-            interpolation=tv_transforms.InterpolationMode.BILINEAR,
-            antialias=False,
-        )
-
-        mask = mask.to(device=latents.device, dtype=latents.dtype)
-        return mask
+        return step_callback
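A note on why _prep_inpaint_mask now expands the mask to the full latent shape before packing: pack() folds the channel dimension into the patch embeddings, so a (b, 1, h, w) mask would stop lining up with (b, c, h, w) latents once both are packed. A quick shape check, re-using the pack() rearrange from sampling_utils.py and hypothetical sizes:

import torch
from einops import rearrange


def pack(x: torch.Tensor) -> torch.Tensor:
    # Same rearrange as invokeai.backend.flux.sampling_utils.pack().
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


latents = torch.randn(1, 16, 64, 64)  # unpacked latent image
mask = torch.rand(1, 1, 64, 64)       # single-channel inpaint mask

mask = mask.expand_as(latents)        # (1, 16, 64, 64): one mask value per latent element
assert pack(mask).shape == pack(latents).shape  # both (1, 1024, 64), so they align elementwise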

--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -3,7 +3,7 @@ from typing import Callable
 import torch
 from tqdm import tqdm

-from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
+from invokeai.backend.flux.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.model import Flux
@@ -19,10 +19,7 @@ def denoise(
     timesteps: list[float],
     step_callback: Callable[[], None],
     guidance: float,
-    # For inpainting:
-    init_latents: torch.Tensor | None,
-    noise: torch.Tensor,
-    inpaint_mask: torch.Tensor | None,
+    inpaint_extension: InpaintExtension | None,
 ):
     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -40,15 +37,8 @@ def denoise(
         img = img + (t_prev - t_curr) * pred

-        if inpaint_mask is not None:
-            assert init_latents is not None
-            img = merge_intermediate_latents_with_init_latents(
-                init_latents=init_latents,
-                intermediate_latents=img,
-                timestep=t_prev,
-                noise=noise,
-                inpaint_mask=inpaint_mask,
-            )
+        if inpaint_extension is not None:
+            img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)

         step_callback()
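With the inpainting parameters collapsed into one optional object, the loop body reduces to a model prediction, an Euler update, and a single inpainting hook. A minimal sketch of the resulting control flow, with the transformer forward pass stubbed out and hypothetical 'packed' shapes:

import torch

from invokeai.backend.flux.inpaint_extension import InpaintExtension

init_latents = torch.randn(1, 1024, 64)  # hypothetical 'packed' latents
noise = torch.randn_like(init_latents)
inpaint_mask = torch.ones_like(init_latents)  # 1 = re-generate everywhere

inpaint_extension = InpaintExtension(init_latents=init_latents, inpaint_mask=inpaint_mask, noise=noise)

img = noise.clone()
timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    pred = torch.zeros_like(img)  # stands in for the transformer prediction
    img = img + (t_prev - t_curr) * pred

    # The single inpainting hook that remains in denoise().
    if inpaint_extension is not None:
        img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)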

--- a/invokeai/backend/flux/inpaint.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-
-
-def merge_intermediate_latents_with_init_latents(
-    init_latents: torch.Tensor,
-    intermediate_latents: torch.Tensor,
-    timestep: float,
-    noise: torch.Tensor,
-    inpaint_mask: torch.Tensor,
-) -> torch.Tensor:
-    # Noise the init_latents for the current timestep.
-    noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents
-
-    # Merge the intermediate_latents with the noised_init_latents using the inpaint_mask.
-    return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)

--- /dev/null
+++ b/invokeai/backend/flux/inpaint_extension.py
@@ -0,0 +1,35 @@
+import torch
+
+
+class InpaintExtension:
+    """A class for managing inpainting with FLUX."""
+
+    def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
+        """Initialize InpaintExtension.
+
+        Args:
+            init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
+            inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
+                re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
+                inpainted region with the background. In 'packed' format.
+            noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
+        """
+        assert init_latents.shape == inpaint_mask.shape == noise.shape
+        self._init_latents = init_latents
+        self._inpaint_mask = inpaint_mask
+        self._noise = noise
+
+    def merge_intermediate_latents_with_init_latents(
+        self, intermediate_latents: torch.Tensor, timestep: float
+    ) -> torch.Tensor:
+        """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask.
+        I.e. update the intermediate latents to keep the regions that are not being inpainted on the correct noise
+        trajectory.
+
+        This function should be called after each denoising step.
+        """
+        # Noise the init latents for the current timestep.
+        noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
+
+        # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
+        return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
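The merge is what keeps un-inpainted regions on the correct noise trajectory: at any timestep t the forward process gives t * noise + (1 - t) * init_latents, and the mask splices that value in wherever regeneration is not wanted. A small self-contained check, using hypothetical packed shapes and an arbitrary half-and-half mask:

import torch

from invokeai.backend.flux.inpaint_extension import InpaintExtension

init_latents = torch.randn(1, 1024, 64)  # 'packed' clean latents
noise = torch.randn_like(init_latents)
mask = torch.zeros_like(init_latents)    # 0 = keep, 1 = re-generate
mask[:, :512, :] = 1.0                   # inpaint the first half of the patches

ext = InpaintExtension(init_latents=init_latents, inpaint_mask=mask, noise=noise)

intermediate = torch.randn_like(init_latents)  # stands in for a mid-denoise state
t_prev = 0.5
merged = ext.merge_intermediate_latents_with_init_latents(intermediate, t_prev)

# Kept region follows the forward-noised init latents; masked region is untouched.
expected = t_prev * noise + (1.0 - t_prev) * init_latents
assert torch.equal(merged[:, 512:, :], expected[:, 512:, :])
assert torch.equal(merged[:, :512, :], intermediate[:, :512, :])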

--- a/invokeai/backend/flux/sampling_utils.py
+++ b/invokeai/backend/flux/sampling_utils.py
@@ -5,7 +5,6 @@ from typing import Callable
 import torch
 from einops import rearrange, repeat
-from torch import Tensor


 def get_noise(
@@ -31,7 +30,7 @@ def get_noise(
     ).to(device=device, dtype=dtype)


-def time_shift(mu: float, sigma: float, t: Tensor):
+def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
@@ -60,7 +59,8 @@ def get_schedule(
     return timesteps.tolist()


-def unpack(x: Tensor, height: int, width: int) -> Tensor:
+def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+    """Unpack flat array of patch embeddings to latent image."""
     return rearrange(
         x,
         "b (h w) (c ph pw) -> b c (h ph) (w pw)",
@@ -71,39 +71,27 @@ def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
     )


-def pack(x: Tensor) -> Tensor:
+def pack(x: torch.Tensor) -> torch.Tensor:
+    """Pack latent image to flattened array of patch embeddings."""
     # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
     return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

-def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
+def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+    """Generate tensor of image position ids.
+
+    Args:
+        h (int): Height of image in latent space.
+        w (int): Width of image in latent space.
+        batch_size (int): Batch size.
+        device (torch.device): Device.
+        dtype (torch.dtype): dtype.
+
+    Returns:
+        torch.Tensor: Image position ids.
+    """
     img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
     return img_ids
-
-
-def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    """Convert an input image in latent space to patches for diffusion.
-
-    This implementation was extracted from:
-    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
-
-    Args:
-        img (torch.Tensor): Input image in latent space.
-
-    Returns:
-        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
-    """
-    bs, c, h, w = img.shape
-    img = pack(img)
-
-    # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
-    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-    return img, img_ids
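Together, pack(), unpack(), and generate_img_ids() account for the image_seq_len arithmetic in the invocation above: a 2x2 pixel-unshuffle turns a (b, c, h, w) latent into h*w/4 tokens of length 4c, with one position-id row per token. A quick shape check under assumed sizes; the round-trip line additionally assumes unpack() maps pixel dimensions to the patch grid via ceil(dim / 16), as in the upstream FLUX sampling code:

import torch

from invokeai.backend.flux.sampling_utils import generate_img_ids, pack, unpack

b, c, h, w = 1, 16, 64, 64  # hypothetical unpacked latent size
x = torch.randn(b, c, h, w)

packed = pack(x)
image_seq_len = h * w // 4  # matches the invocation's calculation
assert packed.shape == (b, image_seq_len, c * 4)

img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
assert img_ids.shape == (b, image_seq_len, 3)

# unpack() takes the target size in pixel space (8x the latent dims) and inverts pack() exactly.
assert torch.equal(unpack(packed, height=h * 8, width=w * 8), x)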