mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Code cleanup and documentation around FLUX inpainting.
This commit is contained in:
parent
262b67b9cb
commit
75d0558241
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as tv_transforms
|
import torchvision.transforms as tv_transforms
|
||||||
@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
|||||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.flux.denoise import denoise
|
from invokeai.backend.flux.denoise import denoise
|
||||||
|
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||||
from invokeai.backend.flux.model import Flux
|
from invokeai.backend.flux.model import Flux
|
||||||
from invokeai.backend.flux.sampling_utils import (
|
from invokeai.backend.flux.sampling_utils import (
|
||||||
generate_img_ids,
|
generate_img_ids,
|
||||||
@ -31,8 +32,6 @@ from invokeai.backend.flux.sampling_utils import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
EPS = 1e-6
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"flux_text_to_image",
|
"flux_text_to_image",
|
||||||
@ -51,6 +50,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
|
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.denoise_mask,
|
description=FieldDescriptions.denoise_mask,
|
||||||
@ -122,6 +122,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
transformer_info = context.models.load(self.transformer.transformer)
|
transformer_info = context.models.load(self.transformer.transformer)
|
||||||
is_schnell = "schnell" in transformer_info.config.config_path
|
is_schnell = "schnell" in transformer_info.config.config_path
|
||||||
|
|
||||||
|
# Calculate the timestep schedule.
|
||||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||||
timesteps = get_schedule(
|
timesteps = get_schedule(
|
||||||
num_steps=self.num_steps,
|
num_steps=self.num_steps,
|
||||||
@ -130,7 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Prepare input latent image.
|
# Prepare input latent image.
|
||||||
if self.denoising_start > EPS:
|
if self.denoising_start > 1e-5:
|
||||||
# If denoising_start > 0, we are doing image-to-image.
|
# If denoising_start > 0, we are doing image-to-image.
|
||||||
if init_latents is None:
|
if init_latents is None:
|
||||||
raise ValueError("latents must be provided if denoising_start > 0.")
|
raise ValueError("latents must be provided if denoising_start > 0.")
|
||||||
@ -144,16 +145,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
t_0 = timesteps[0]
|
t_0 = timesteps[0]
|
||||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||||
else:
|
else:
|
||||||
# We are not doing image-to-image, so we are starting from noise.
|
# We are not doing image-to-image, so start from noise.
|
||||||
x = noise
|
x = noise
|
||||||
|
|
||||||
# Prepare inpaint mask.
|
|
||||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||||
if inpaint_mask is not None:
|
|
||||||
assert init_latents is not None
|
|
||||||
# Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines
|
|
||||||
# up with the init_latents.
|
|
||||||
inpaint_mask = inpaint_mask.expand_as(init_latents)
|
|
||||||
|
|
||||||
b, _c, h, w = x.shape
|
b, _c, h, w = x.shape
|
||||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||||
@ -167,12 +162,74 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
noise = pack(noise)
|
noise = pack(noise)
|
||||||
x = pack(x)
|
x = pack(x)
|
||||||
|
|
||||||
# Verify that we calculated the image_seq_len correctly.
|
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||||
assert image_seq_len == x.shape[1]
|
assert image_seq_len == x.shape[1]
|
||||||
|
|
||||||
|
# Prepare inpaint extension.
|
||||||
|
inpaint_extension: InpaintExtension | None = None
|
||||||
|
if inpaint_mask is not None:
|
||||||
|
assert init_latents is not None
|
||||||
|
inpaint_extension = InpaintExtension(
|
||||||
|
init_latents=init_latents,
|
||||||
|
inpaint_mask=inpaint_mask,
|
||||||
|
noise=noise,
|
||||||
|
)
|
||||||
|
|
||||||
with transformer_info as transformer:
|
with transformer_info as transformer:
|
||||||
assert isinstance(transformer, Flux)
|
assert isinstance(transformer, Flux)
|
||||||
|
|
||||||
|
x = denoise(
|
||||||
|
model=transformer,
|
||||||
|
img=x,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=t5_embeddings,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
vec=clip_embeddings,
|
||||||
|
timesteps=timesteps,
|
||||||
|
step_callback=self._build_step_callback(context),
|
||||||
|
guidance=self.guidance,
|
||||||
|
inpaint_extension=inpaint_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = unpack(x.float(), self.height, self.width)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||||
|
"""Prepare the inpaint mask.
|
||||||
|
|
||||||
|
- Loads the mask
|
||||||
|
- Resizes if necessary
|
||||||
|
- Casts to same device/dtype as latents
|
||||||
|
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||||
|
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||||
|
device, and dtype for the inpaint mask.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor | None: Inpaint mask.
|
||||||
|
"""
|
||||||
|
if self.denoise_mask is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||||
|
|
||||||
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
mask = tv_resize(
|
||||||
|
img=mask,
|
||||||
|
size=[latent_height, latent_width],
|
||||||
|
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||||
|
antialias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||||
|
|
||||||
|
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||||
|
# `latents`.
|
||||||
|
return mask.expand_as(latents)
|
||||||
|
|
||||||
|
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
|
||||||
def step_callback() -> None:
|
def step_callback() -> None:
|
||||||
if context.util.is_canceled():
|
if context.util.is_canceled():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
@ -202,43 +259,4 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||||
# )
|
# )
|
||||||
|
|
||||||
x = denoise(
|
return step_callback
|
||||||
model=transformer,
|
|
||||||
img=x,
|
|
||||||
img_ids=img_ids,
|
|
||||||
txt=t5_embeddings,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
vec=clip_embeddings,
|
|
||||||
timesteps=timesteps,
|
|
||||||
step_callback=step_callback,
|
|
||||||
guidance=self.guidance,
|
|
||||||
init_latents=init_latents,
|
|
||||||
noise=noise,
|
|
||||||
inpaint_mask=inpaint_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
x = unpack(x.float(), self.height, self.width)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
|
||||||
"""Prepare the inpaint mask.
|
|
||||||
|
|
||||||
Loads the mask, resizes if necessary, casts to same device/dtype as latents.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask)
|
|
||||||
"""
|
|
||||||
if self.denoise_mask is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
|
||||||
mask = tv_resize(
|
|
||||||
img=mask,
|
|
||||||
size=[latent_height, latent_width],
|
|
||||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
|
||||||
antialias=False,
|
|
||||||
)
|
|
||||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
|
||||||
return mask
|
|
||||||
|
@ -3,7 +3,7 @@ from typing import Callable
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
|
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||||
from invokeai.backend.flux.model import Flux
|
from invokeai.backend.flux.model import Flux
|
||||||
|
|
||||||
|
|
||||||
@ -19,10 +19,7 @@ def denoise(
|
|||||||
timesteps: list[float],
|
timesteps: list[float],
|
||||||
step_callback: Callable[[], None],
|
step_callback: Callable[[], None],
|
||||||
guidance: float,
|
guidance: float,
|
||||||
# For inpainting:
|
inpaint_extension: InpaintExtension | None,
|
||||||
init_latents: torch.Tensor | None,
|
|
||||||
noise: torch.Tensor,
|
|
||||||
inpaint_mask: torch.Tensor | None,
|
|
||||||
):
|
):
|
||||||
# guidance_vec is ignored for schnell.
|
# guidance_vec is ignored for schnell.
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
@ -40,15 +37,8 @@ def denoise(
|
|||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
if inpaint_mask is not None:
|
if inpaint_extension is not None:
|
||||||
assert init_latents is not None
|
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||||
img = merge_intermediate_latents_with_init_latents(
|
|
||||||
init_latents=init_latents,
|
|
||||||
intermediate_latents=img,
|
|
||||||
timestep=t_prev,
|
|
||||||
noise=noise,
|
|
||||||
inpaint_mask=inpaint_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
step_callback()
|
step_callback()
|
||||||
|
|
||||||
|
@ -1,15 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def merge_intermediate_latents_with_init_latents(
|
|
||||||
init_latents: torch.Tensor,
|
|
||||||
intermediate_latents: torch.Tensor,
|
|
||||||
timestep: float,
|
|
||||||
noise: torch.Tensor,
|
|
||||||
inpaint_mask: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Noise the init_latents for the current timestep.
|
|
||||||
noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents
|
|
||||||
|
|
||||||
# Merge the intermediate_latents with the noised_init_latents using the inpaint_mask.
|
|
||||||
return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)
|
|
35
invokeai/backend/flux/inpaint_extension.py
Normal file
35
invokeai/backend/flux/inpaint_extension.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintExtension:
|
||||||
|
"""A class for managing inpainting with FLUX."""
|
||||||
|
|
||||||
|
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||||
|
"""Initialize InpaintExtension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||||
|
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||||
|
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||||
|
inpainted region with the background. In 'packed' format.
|
||||||
|
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||||
|
"""
|
||||||
|
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||||
|
self._init_latents = init_latents
|
||||||
|
self._inpaint_mask = inpaint_mask
|
||||||
|
self._noise = noise
|
||||||
|
|
||||||
|
def merge_intermediate_latents_with_init_latents(
|
||||||
|
self, intermediate_latents: torch.Tensor, timestep: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||||
|
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||||
|
trajectory.
|
||||||
|
|
||||||
|
This function should be called after each denoising step.
|
||||||
|
"""
|
||||||
|
# Noise the init latents for the current timestep.
|
||||||
|
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||||
|
|
||||||
|
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||||
|
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
@ -5,7 +5,6 @@ from typing import Callable
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
def get_noise(
|
def get_noise(
|
||||||
@ -31,7 +30,7 @@ def get_noise(
|
|||||||
).to(device=device, dtype=dtype)
|
).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
@ -60,7 +59,8 @@ def get_schedule(
|
|||||||
return timesteps.tolist()
|
return timesteps.tolist()
|
||||||
|
|
||||||
|
|
||||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""Unpack flat array of patch embeddings to latent image."""
|
||||||
return rearrange(
|
return rearrange(
|
||||||
x,
|
x,
|
||||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||||
@ -71,39 +71,27 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def pack(x: Tensor) -> Tensor:
|
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Pack latent image to flattented array of patch embeddings."""
|
||||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
|
||||||
|
|
||||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
|
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
"""Generate tensor of image position ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (int): Height of image in latent space.
|
||||||
|
w (int): Width of image in latent space.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
device (torch.device): Device.
|
||||||
|
dtype (torch.dtype): dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Image position ids.
|
||||||
|
"""
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||||
return img_ids
|
return img_ids
|
||||||
|
|
||||||
|
|
||||||
def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Convert an input image in latent space to patches for diffusion.
|
|
||||||
|
|
||||||
This implementation was extracted from:
|
|
||||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img (torch.Tensor): Input image in latent space.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
|
||||||
"""
|
|
||||||
bs, c, h, w = img.shape
|
|
||||||
|
|
||||||
img = pack(img)
|
|
||||||
|
|
||||||
# Generate patch position ids.
|
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
return img, img_ids
|
|
||||||
|
Loading…
Reference in New Issue
Block a user