mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Code cleanup and documentation around FLUX inpainting.
This commit is contained in:
parent
262b67b9cb
commit
75d0558241
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
generate_img_ids,
|
||||
@ -31,8 +32,6 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
@ -51,6 +50,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
@ -122,6 +122,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
@ -130,7 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
# Prepare input latent image.
|
||||
if self.denoising_start > EPS:
|
||||
if self.denoising_start > 1e-5:
|
||||
# If denoising_start > 0, we are doing image-to-image.
|
||||
if init_latents is None:
|
||||
raise ValueError("latents must be provided if denoising_start > 0.")
|
||||
@ -144,16 +145,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# We are not doing image-to-image, so we are starting from noise.
|
||||
# We are not doing image-to-image, so start from noise.
|
||||
x = noise
|
||||
|
||||
# Prepare inpaint mask.
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
# Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines
|
||||
# up with the init_latents.
|
||||
inpaint_mask = inpaint_mask.expand_as(init_latents)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
@ -167,12 +162,74 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Verify that we calculated the image_seq_len correctly.
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||
device, and dtype for the inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
@ -202,43 +259,4 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=step_callback,
|
||||
guidance=self.guidance,
|
||||
init_latents=init_latents,
|
||||
noise=noise,
|
||||
inpaint_mask=inpaint_mask,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
Loads the mask, resizes if necessary, casts to same device/dtype as latents.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask)
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
return mask
|
||||
return step_callback
|
||||
|
@ -3,7 +3,7 @@ from typing import Callable
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
|
||||
|
||||
@ -19,10 +19,7 @@ def denoise(
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float,
|
||||
# For inpainting:
|
||||
init_latents: torch.Tensor | None,
|
||||
noise: torch.Tensor,
|
||||
inpaint_mask: torch.Tensor | None,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
@ -40,15 +37,8 @@ def denoise(
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
img = merge_intermediate_latents_with_init_latents(
|
||||
init_latents=init_latents,
|
||||
intermediate_latents=img,
|
||||
timestep=t_prev,
|
||||
noise=noise,
|
||||
inpaint_mask=inpaint_mask,
|
||||
)
|
||||
if inpaint_extension is not None:
|
||||
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||
|
||||
step_callback()
|
||||
|
||||
|
@ -1,15 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
init_latents: torch.Tensor,
|
||||
intermediate_latents: torch.Tensor,
|
||||
timestep: float,
|
||||
noise: torch.Tensor,
|
||||
inpaint_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Noise the init_latents for the current timestep.
|
||||
noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents
|
||||
|
||||
# Merge the intermediate_latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)
|
35
invokeai/backend/flux/inpaint_extension.py
Normal file
35
invokeai/backend/flux/inpaint_extension.py
Normal file
@ -0,0 +1,35 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with FLUX."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background. In 'packed' format.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||
"""
|
||||
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, timestep: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
@ -5,7 +5,6 @@ from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def get_noise(
|
||||
@ -31,7 +30,7 @@ def get_noise(
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
@ -60,7 +59,8 @@ def get_schedule(
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""Unpack flat array of patch embeddings to latent image."""
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
@ -71,39 +71,27 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
)
|
||||
|
||||
|
||||
def pack(x: Tensor) -> Tensor:
|
||||
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack latent image to flattented array of patch embeddings."""
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids.
|
||||
|
||||
Args:
|
||||
h (int): Height of image in latent space.
|
||||
w (int): Width of image in latent space.
|
||||
batch_size (int): Batch size.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): dtype.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids.
|
||||
"""
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
|
||||
|
||||
def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Input image in latent space.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = img.shape
|
||||
|
||||
img = pack(img)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
|
Loading…
Reference in New Issue
Block a user