From 7d854f32b08012052db9cee57c92da8161827b00 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 29 Aug 2024 19:05:44 +0000 Subject: [PATCH] Get a rough version of FLUX inpainting working. --- invokeai/app/invocations/denoise_latents.py | 2 +- invokeai/app/invocations/fields.py | 2 +- .../app/invocations/flux_text_to_image.py | 75 +++++++++++++-- invokeai/backend/flux/denoise.py | 55 +++++++++++ invokeai/backend/flux/inpaint.py | 15 +++ .../flux/{sampling.py => sampling_utils.py} | 96 ++++--------------- 6 files changed, 160 insertions(+), 85 deletions(-) create mode 100644 invokeai/backend/flux/denoise.py create mode 100644 invokeai/backend/flux/inpaint.py rename invokeai/backend/flux/{sampling.py => sampling_utils.py} (53%) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index d97f92d42c..f8028b1933 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, - description=FieldDescriptions.mask, + description=FieldDescriptions.denoise_mask, input=Input.Connection, ui_order=8, ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 03654dd78d..bd841808f4 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -181,7 +181,7 @@ class FieldDescriptions: ) num_1 = "The first number" num_2 = "The second number" - mask = "The mask to use for the operation" + denoise_mask = "A mask of the region to apply the denoising process to." board = "The board to save the image to" image = "The image to process" tile_size = "Tile size" diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index a813243c85..03a0965276 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,9 +1,12 @@ from typing import Optional import torch +import torchvision.transforms as tv_transforms +from torchvision.transforms.functional import resize as tv_resize from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.fields import ( + DenoiseMaskField, FieldDescriptions, FluxConditioningField, Input, @@ -16,8 +19,15 @@ from invokeai.app.invocations.model import TransformerField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.denoise import denoise from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack +from invokeai.backend.flux.sampling_utils import ( + generate_img_ids, + get_noise, + get_schedule, + pack, + unpack, +) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice @@ -41,6 +51,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.latents, input=Input.Connection, ) + denoise_mask: Optional[DenoiseMaskField] = InputField( + default=None, + description=FieldDescriptions.denoise_mask, + input=Input.Connection, + ) denoising_start: float = InputField( default=0.0, ge=0, @@ -95,7 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype) # Prepare input noise. - x = get_noise( + noise = get_noise( num_samples=1, height=self.height, width=self.width, @@ -107,14 +122,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): transformer_info = context.models.load(self.transformer.transformer) is_schnell = "schnell" in transformer_info.config.config_path + image_seq_len = noise.shape[-1] * noise.shape[-2] // 4 timesteps = get_schedule( num_steps=self.num_steps, - image_seq_len=x.shape[1], + image_seq_len=image_seq_len, shift=not is_schnell, ) - # Prepare inputs for image-to-image case. + # Prepare input latent image. if self.denoising_start > EPS: + # If denoising_start > 0, we are doing image-to-image. if init_latents is None: raise ValueError("latents must be provided if denoising_start > 0.") @@ -125,13 +142,34 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # Noise the orig_latents by the appropriate amount for the first timestep. t_0 = timesteps[0] - x = t_0 * x + (1.0 - t_0) * init_latents + x = t_0 * noise + (1.0 - t_0) * init_latents + else: + # We are not doing image-to-image, so we are starting from noise. + x = noise - x, img_ids = prepare_latent_img_patches(x) + # Prepare inpaint mask. + inpaint_mask = self._prep_inpaint_mask(context, x) + if inpaint_mask is not None: + assert init_latents is not None + # Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines + # up with the init_latents. + inpaint_mask = inpaint_mask.expand_as(init_latents) + + b, _c, h, w = x.shape + img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype) bs, t5_seq_len, _ = t5_embeddings.shape txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()) + # Pack all latent tensors. + init_latents = pack(init_latents) if init_latents is not None else None + inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None + noise = pack(noise) + x = pack(x) + + # Verify that we calculated the image_seq_len correctly. + assert image_seq_len == x.shape[1] + with transformer_info as transformer: assert isinstance(transformer, Flux) @@ -174,8 +212,33 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): timesteps=timesteps, step_callback=step_callback, guidance=self.guidance, + init_latents=init_latents, + noise=noise, + inpaint_mask=inpaint_mask, ) x = unpack(x.float(), self.height, self.width) return x + + def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: + """Prepare the inpaint mask. + + Loads the mask, resizes if necessary, casts to same device/dtype as latents. + + Returns: + tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask) + """ + if self.denoise_mask is None: + return None + + mask = context.tensors.load(self.denoise_mask.mask_name) + _, _, latent_height, latent_width = latents.shape + mask = tv_resize( + img=mask, + size=[latent_height, latent_width], + interpolation=tv_transforms.InterpolationMode.BILINEAR, + antialias=False, + ) + mask = mask.to(device=latents.device, dtype=latents.dtype) + return mask diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py new file mode 100644 index 0000000000..103fcd907b --- /dev/null +++ b/invokeai/backend/flux/denoise.py @@ -0,0 +1,55 @@ +from typing import Callable + +import torch +from tqdm import tqdm + +from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents +from invokeai.backend.flux.model import Flux + + +def denoise( + model: Flux, + # model input + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + # sampling parameters + timesteps: list[float], + step_callback: Callable[[], None], + guidance: float, + # For inpainting: + init_latents: torch.Tensor | None, + noise: torch.Tensor, + inpaint_mask: torch.Tensor | None, +): + # guidance_vec is ignored for schnell. + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + + img = img + (t_prev - t_curr) * pred + + if inpaint_mask is not None: + assert init_latents is not None + img = merge_intermediate_latents_with_init_latents( + init_latents=init_latents, + intermediate_latents=img, + timestep=t_prev, + noise=noise, + inpaint_mask=inpaint_mask, + ) + + step_callback() + + return img diff --git a/invokeai/backend/flux/inpaint.py b/invokeai/backend/flux/inpaint.py new file mode 100644 index 0000000000..3bebb9c3e6 --- /dev/null +++ b/invokeai/backend/flux/inpaint.py @@ -0,0 +1,15 @@ +import torch + + +def merge_intermediate_latents_with_init_latents( + init_latents: torch.Tensor, + intermediate_latents: torch.Tensor, + timestep: float, + noise: torch.Tensor, + inpaint_mask: torch.Tensor, +) -> torch.Tensor: + # Noise the init_latents for the current timestep. + noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents + + # Merge the intermediate_latents with the noised_init_latents using the inpaint_mask. + return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask) diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling_utils.py similarity index 53% rename from invokeai/backend/flux/sampling.py rename to invokeai/backend/flux/sampling_utils.py index 3c3103411a..9d710015af 100644 --- a/invokeai/backend/flux/sampling.py +++ b/invokeai/backend/flux/sampling_utils.py @@ -6,10 +6,6 @@ from typing import Callable import torch from einops import rearrange, repeat from torch import Tensor -from tqdm import tqdm - -from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.modules.conditioner import HFEncoder def get_noise( @@ -35,40 +31,6 @@ def get_noise( ).to(device=device, dtype=dtype) -def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: - bs, c, h, w = img.shape - if bs == 1 and not isinstance(prompt, str): - bs = len(prompt) - - img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = repeat(img, "1 ... -> bs ...", bs=bs) - - img_ids = torch.zeros(h // 2, w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - if isinstance(prompt, str): - prompt = [prompt] - txt = t5(prompt) - if txt.shape[0] == 1 and bs > 1: - txt = repeat(txt, "1 ... -> bs ...", bs=bs) - txt_ids = torch.zeros(bs, txt.shape[1], 3) - - vec = clip(prompt) - if vec.shape[0] == 1 and bs > 1: - vec = repeat(vec, "1 ... -> bs ...", bs=bs) - - return { - "img": img, - "img_ids": img_ids.to(img.device), - "txt": txt.to(img.device), - "txt_ids": txt_ids.to(img.device), - "vec": vec.to(img.device), - } - - def time_shift(mu: float, sigma: float, t: Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) @@ -98,39 +60,6 @@ def get_schedule( return timesteps.tolist() -def denoise( - model: Flux, - # model input - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - vec: Tensor, - # sampling parameters - timesteps: list[float], - step_callback: Callable[[], None], - guidance: float = 4.0, -): - # guidance_vec is ignored for schnell. - guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t_vec, - guidance=guidance_vec, - ) - - img = img + (t_prev - t_curr) * pred - step_callback() - - return img - - def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, @@ -142,21 +71,34 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor: ) -def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def pack(x: Tensor) -> Tensor: + # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. + return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + +def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor: + img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert an input image in latent space to patches for diffusion. This implementation was extracted from: https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32 + Args: + img (torch.Tensor): Input image in latent space. + Returns: tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. """ - bs, c, h, w = latent_img.shape + bs, c, h, w = img.shape - # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. - img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = repeat(img, "1 ... -> bs ...", bs=bs) + img = pack(img) # Generate patch position ids. img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)