Get a rough version of FLUX inpainting working.

Ryan Dick 2024-08-29 19:05:44 +00:00
parent e0f12c762e
commit 7d854f32b0
6 changed files with 160 additions and 85 deletions

View File

@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
     )
     denoise_mask: Optional[DenoiseMaskField] = InputField(
         default=None,
-        description=FieldDescriptions.mask,
+        description=FieldDescriptions.denoise_mask,
         input=Input.Connection,
         ui_order=8,
     )

View File

@@ -181,7 +181,7 @@ class FieldDescriptions:
     )
     num_1 = "The first number"
     num_2 = "The second number"
-    mask = "The mask to use for the operation"
+    denoise_mask = "A mask of the region to apply the denoising process to."
    board = "The board to save the image to"
    image = "The image to process"
    tile_size = "Tile size"

View File

@@ -1,9 +1,12 @@
 from typing import Optional

 import torch
+import torchvision.transforms as tv_transforms
+from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
+    DenoiseMaskField,
     FieldDescriptions,
     FluxConditioningField,
     Input,
@@ -16,8 +19,15 @@ from invokeai.app.invocations.model import TransformerField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
+from invokeai.backend.flux.sampling_utils import (
+    generate_img_ids,
+    get_noise,
+    get_schedule,
+    pack,
+    unpack,
+)
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice

@@ -41,6 +51,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         description=FieldDescriptions.latents,
         input=Input.Connection,
     )
+    denoise_mask: Optional[DenoiseMaskField] = InputField(
+        default=None,
+        description=FieldDescriptions.denoise_mask,
+        input=Input.Connection,
+    )
     denoising_start: float = InputField(
         default=0.0,
         ge=0,
@@ -95,7 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)

         # Prepare input noise.
-        x = get_noise(
+        noise = get_noise(
             num_samples=1,
             height=self.height,
             width=self.width,
@@ -107,14 +122,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path

+        image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=x.shape[1],
+            image_seq_len=image_seq_len,
             shift=not is_schnell,
         )

-        # Prepare inputs for image-to-image case.
+        # Prepare input latent image.
         if self.denoising_start > EPS:
+            # If denoising_start > 0, we are doing image-to-image.
             if init_latents is None:
                 raise ValueError("latents must be provided if denoising_start > 0.")

@@ -125,13 +142,34 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             # Noise the orig_latents by the appropriate amount for the first timestep.
             t_0 = timesteps[0]
-            x = t_0 * x + (1.0 - t_0) * init_latents
+            x = t_0 * noise + (1.0 - t_0) * init_latents
+        else:
+            # We are not doing image-to-image, so we are starting from noise.
+            x = noise

-        x, img_ids = prepare_latent_img_patches(x)
+        # Prepare inpaint mask.
+        inpaint_mask = self._prep_inpaint_mask(context, x)
+        if inpaint_mask is not None:
+            assert init_latents is not None
+            # Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines
+            # up with the init_latents.
+            inpaint_mask = inpaint_mask.expand_as(init_latents)
+
+        b, _c, h, w = x.shape
+        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)

         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

+        # Pack all latent tensors.
+        init_latents = pack(init_latents) if init_latents is not None else None
+        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
+        noise = pack(noise)
+        x = pack(x)
+
+        # Verify that we calculated the image_seq_len correctly.
+        assert image_seq_len == x.shape[1]
+
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
@@ -174,8 +212,33 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 timesteps=timesteps,
                 step_callback=step_callback,
                 guidance=self.guidance,
+                init_latents=init_latents,
+                noise=noise,
+                inpaint_mask=inpaint_mask,
             )

         x = unpack(x.float(), self.height, self.width)
         return x

+    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
+        """Prepare the inpaint mask.
+
+        Loads the mask, resizes it if necessary, and casts it to the same device/dtype as the latents.
+
+        Returns:
+            torch.Tensor | None: The prepared inpaint mask, or None if no mask was provided.
+        """
+        if self.denoise_mask is None:
+            return None
+
+        mask = context.tensors.load(self.denoise_mask.mask_name)
+        _, _, latent_height, latent_width = latents.shape
+        mask = tv_resize(
+            img=mask,
+            size=[latent_height, latent_width],
+            interpolation=tv_transforms.InterpolationMode.BILINEAR,
+            antialias=False,
+        )
+
+        mask = mask.to(device=latents.device, dtype=latents.dtype)
+        return mask
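
The image_seq_len bookkeeping above follows directly from the 2x2 patch packing. A minimal sketch (not part of this diff) that checks the relationship with einops directly, assuming a 1024x768 image and FLUX's 16 latent channels for illustration:

import torch
from einops import rearrange

height, width = 1024, 768  # image size in pixels (assumed for illustration)
latents = torch.randn(1, 16, height // 8, width // 8)  # latents are 1/8 the image resolution

# Same layout as pack(): pixel-unshuffle into 2x2 patches and flatten height/width into a sequence.
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

image_seq_len = latents.shape[-1] * latents.shape[-2] // 4
assert packed.shape == (1, image_seq_len, 16 * 2 * 2)  # (1, 3072, 64)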

View File

@@ -0,0 +1,55 @@
+from typing import Callable
+
+import torch
+from tqdm import tqdm
+
+from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
+from invokeai.backend.flux.model import Flux
+
+
+def denoise(
+    model: Flux,
+    # model input
+    img: torch.Tensor,
+    img_ids: torch.Tensor,
+    txt: torch.Tensor,
+    txt_ids: torch.Tensor,
+    vec: torch.Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    step_callback: Callable[[], None],
+    guidance: float,
+    # For inpainting:
+    init_latents: torch.Tensor | None,
+    noise: torch.Tensor,
+    inpaint_mask: torch.Tensor | None,
+):
+    # guidance_vec is ignored for schnell.
+    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
+        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        pred = model(
+            img=img,
+            img_ids=img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec,
+        )
+
+        img = img + (t_prev - t_curr) * pred
+
+        if inpaint_mask is not None:
+            assert init_latents is not None
+            img = merge_intermediate_latents_with_init_latents(
+                init_latents=init_latents,
+                intermediate_latents=img,
+                timestep=t_prev,
+                noise=noise,
+                inpaint_mask=inpaint_mask,
+            )
+
+        step_callback()
+
+    return img
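
The loop above is a plain Euler integrator for the rectified-flow update img = img + (t_prev - t_curr) * pred, with t stepping from 1.0 (pure noise) toward 0.0 and x_t = t * noise + (1 - t) * data (the same convention used by merge_intermediate_latents_with_init_latents). A minimal, self-contained sketch (not part of this diff) showing that the update recovers the data exactly when the model returns the true velocity noise - data:

import torch

data = torch.tensor([1.0, -2.0, 0.5])
noise = torch.randn(3)
timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]

img = noise.clone()  # start from pure noise (t = 1.0)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    pred = noise - data  # oracle velocity; in the real sampler this is the transformer's prediction
    img = img + (t_prev - t_curr) * pred

assert torch.allclose(img, data, atol=1e-6)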

View File

@@ -0,0 +1,15 @@
+import torch
+
+
+def merge_intermediate_latents_with_init_latents(
+    init_latents: torch.Tensor,
+    intermediate_latents: torch.Tensor,
+    timestep: float,
+    noise: torch.Tensor,
+    inpaint_mask: torch.Tensor,
+) -> torch.Tensor:
+    # Noise the init_latents for the current timestep.
+    noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents
+
+    # Merge the intermediate_latents with the noised_init_latents using the inpaint_mask.
+    return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)
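
A small usage sketch (not part of this diff) of the merge on toy tensors, assuming this module is importable as invokeai.backend.flux.inpaint: positions where inpaint_mask is 1 keep the freshly denoised latents, while the remaining positions are pinned to the original latents re-noised to the current timestep.

import torch

from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents

init_latents = torch.zeros(1, 4, 2, 2)         # stand-in for the original (packed) latents
intermediate_latents = torch.ones(1, 4, 2, 2)  # stand-in for the partially denoised latents
noise = torch.zeros(1, 4, 2, 2)                # zero noise keeps the noised init latents recognizable
inpaint_mask = torch.zeros(1, 4, 2, 2)
inpaint_mask[..., 0] = 1.0                     # regenerate only the first column

merged = merge_intermediate_latents_with_init_latents(
    init_latents=init_latents,
    intermediate_latents=intermediate_latents,
    timestep=0.5,
    noise=noise,
    inpaint_mask=inpaint_mask,
)

assert torch.equal(merged[..., 0], torch.ones(1, 4, 2))   # masked region: denoised latents kept
assert torch.equal(merged[..., 1], torch.zeros(1, 4, 2))  # unmasked region: noised init latents (all zeros here)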

View File

@@ -6,10 +6,6 @@ from typing import Callable
 import torch
 from einops import rearrange, repeat
 from torch import Tensor
-from tqdm import tqdm
-
-from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.modules.conditioner import HFEncoder


 def get_noise(
@@ -35,40 +31,6 @@ def get_noise(
     ).to(device=device, dtype=dtype)


-def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
-    bs, c, h, w = img.shape
-    if bs == 1 and not isinstance(prompt, str):
-        bs = len(prompt)
-
-    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-    if img.shape[0] == 1 and bs > 1:
-        img = repeat(img, "1 ... -> bs ...", bs=bs)
-
-    img_ids = torch.zeros(h // 2, w // 2, 3)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
-    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-    if isinstance(prompt, str):
-        prompt = [prompt]
-    txt = t5(prompt)
-    if txt.shape[0] == 1 and bs > 1:
-        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
-    txt_ids = torch.zeros(bs, txt.shape[1], 3)
-
-    vec = clip(prompt)
-    if vec.shape[0] == 1 and bs > 1:
-        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
-
-    return {
-        "img": img,
-        "img_ids": img_ids.to(img.device),
-        "txt": txt.to(img.device),
-        "txt_ids": txt_ids.to(img.device),
-        "vec": vec.to(img.device),
-    }
-
-
 def time_shift(mu: float, sigma: float, t: Tensor):
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
@@ -98,39 +60,6 @@ def get_schedule(
     return timesteps.tolist()


-def denoise(
-    model: Flux,
-    # model input
-    img: Tensor,
-    img_ids: Tensor,
-    txt: Tensor,
-    txt_ids: Tensor,
-    vec: Tensor,
-    # sampling parameters
-    timesteps: list[float],
-    step_callback: Callable[[], None],
-    guidance: float = 4.0,
-):
-    # guidance_vec is ignored for schnell.
-    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
-    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
-        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        pred = model(
-            img=img,
-            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
-            timesteps=t_vec,
-            guidance=guidance_vec,
-        )
-
-        img = img + (t_prev - t_curr) * pred
-        step_callback()
-
-    return img
-
-
 def unpack(x: Tensor, height: int, width: int) -> Tensor:
     return rearrange(
         x,
@@ -142,21 +71,34 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
     )


-def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def pack(x: Tensor) -> Tensor:
+    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+
+
+def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+    return img_ids
+
+
+def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """Convert an input image in latent space to patches for diffusion.

     This implementation was extracted from:
     https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

+    Args:
+        img (torch.Tensor): Input image in latent space.
+
     Returns:
         tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
     """
-    bs, c, h, w = latent_img.shape
+    bs, c, h, w = img.shape

-    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
-    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-    if img.shape[0] == 1 and bs > 1:
-        img = repeat(img, "1 ... -> bs ...", bs=bs)
+    img = pack(img)

     # Generate patch position ids.
     img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)