diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 03a0965276..68b46b11a7 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torchvision.transforms as tv_transforms @@ -20,6 +20,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.denoise import denoise +from invokeai.backend.flux.inpaint_extension import InpaintExtension from invokeai.backend.flux.model import Flux from invokeai.backend.flux.sampling_utils import ( generate_img_ids, @@ -31,8 +32,6 @@ from invokeai.backend.flux.sampling_utils import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice -EPS = 1e-6 - @invocation( "flux_text_to_image", @@ -51,6 +50,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.latents, input=Input.Connection, ) + # denoise_mask is used for image-to-image inpainting. Only the masked region is modified. denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, description=FieldDescriptions.denoise_mask, @@ -122,6 +122,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): transformer_info = context.models.load(self.transformer.transformer) is_schnell = "schnell" in transformer_info.config.config_path + # Calculate the timestep schedule. image_seq_len = noise.shape[-1] * noise.shape[-2] // 4 timesteps = get_schedule( num_steps=self.num_steps, @@ -130,7 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) # Prepare input latent image. - if self.denoising_start > EPS: + if self.denoising_start > 1e-5: # If denoising_start > 0, we are doing image-to-image. if init_latents is None: raise ValueError("latents must be provided if denoising_start > 0.") @@ -144,16 +145,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): t_0 = timesteps[0] x = t_0 * noise + (1.0 - t_0) * init_latents else: - # We are not doing image-to-image, so we are starting from noise. + # We are not doing image-to-image, so start from noise. x = noise - # Prepare inpaint mask. inpaint_mask = self._prep_inpaint_mask(context, x) - if inpaint_mask is not None: - assert init_latents is not None - # Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines - # up with the init_latents. - inpaint_mask = inpaint_mask.expand_as(init_latents) b, _c, h, w = x.shape img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype) @@ -167,41 +162,22 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): noise = pack(noise) x = pack(x) - # Verify that we calculated the image_seq_len correctly. + # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly. assert image_seq_len == x.shape[1] + # Prepare inpaint extension. + inpaint_extension: InpaintExtension | None = None + if inpaint_mask is not None: + assert init_latents is not None + inpaint_extension = InpaintExtension( + init_latents=init_latents, + inpaint_mask=inpaint_mask, + noise=noise, + ) + with transformer_info as transformer: assert isinstance(transformer, Flux) - def step_callback() -> None: - if context.util.is_canceled(): - raise CanceledException - - # TODO: Make this look like the image before re-enabling - # latent_image = unpack(img.float(), self.height, self.width) - # latent_image = latent_image.squeeze() # Remove unnecessary dimensions - # flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128] - - # # Create a new tensor of the required shape [255, 255, 3] - # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format - - # # Convert to a NumPy array and then to a PIL Image - # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8)) - - # (width, height) = image.size - # width *= 8 - # height *= 8 - - # dataURL = image_to_dataURL(image, image_format="JPEG") - - # # TODO: move this whole function to invocation context to properly reference these variables - # context._services.events.emit_invocation_denoise_progress( - # context._data.queue_item, - # context._data.invocation, - # state, - # ProgressImage(dataURL=dataURL, width=width, height=height), - # ) - x = denoise( model=transformer, img=x, @@ -210,29 +186,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): txt_ids=txt_ids, vec=clip_embeddings, timesteps=timesteps, - step_callback=step_callback, + step_callback=self._build_step_callback(context), guidance=self.guidance, - init_latents=init_latents, - noise=noise, - inpaint_mask=inpaint_mask, + inpaint_extension=inpaint_extension, ) x = unpack(x.float(), self.height, self.width) - return x def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: """Prepare the inpaint mask. - Loads the mask, resizes if necessary, casts to same device/dtype as latents. + - Loads the mask + - Resizes if necessary + - Casts to same device/dtype as latents + - Expands mask to the same shape as latents so that they line up after 'packing' + + Args: + context (InvocationContext): The invocation context, for loading the inpaint mask. + latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape, + device, and dtype for the inpaint mask. Returns: - tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask) + torch.Tensor | None: Inpaint mask. """ if self.denoise_mask is None: return None mask = context.tensors.load(self.denoise_mask.mask_name) + _, _, latent_height, latent_width = latents.shape mask = tv_resize( img=mask, @@ -240,5 +222,41 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): interpolation=tv_transforms.InterpolationMode.BILINEAR, antialias=False, ) + mask = mask.to(device=latents.device, dtype=latents.dtype) - return mask + + # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with + # `latents`. + return mask.expand_as(latents) + + def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]: + def step_callback() -> None: + if context.util.is_canceled(): + raise CanceledException + + # TODO: Make this look like the image before re-enabling + # latent_image = unpack(img.float(), self.height, self.width) + # latent_image = latent_image.squeeze() # Remove unnecessary dimensions + # flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128] + + # # Create a new tensor of the required shape [255, 255, 3] + # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format + + # # Convert to a NumPy array and then to a PIL Image + # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8)) + + # (width, height) = image.size + # width *= 8 + # height *= 8 + + # dataURL = image_to_dataURL(image, image_format="JPEG") + + # # TODO: move this whole function to invocation context to properly reference these variables + # context._services.events.emit_invocation_denoise_progress( + # context._data.queue_item, + # context._data.invocation, + # state, + # ProgressImage(dataURL=dataURL, width=width, height=height), + # ) + + return step_callback diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 103fcd907b..4fb9a792dd 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -3,7 +3,7 @@ from typing import Callable import torch from tqdm import tqdm -from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents +from invokeai.backend.flux.inpaint_extension import InpaintExtension from invokeai.backend.flux.model import Flux @@ -19,10 +19,7 @@ def denoise( timesteps: list[float], step_callback: Callable[[], None], guidance: float, - # For inpainting: - init_latents: torch.Tensor | None, - noise: torch.Tensor, - inpaint_mask: torch.Tensor | None, + inpaint_extension: InpaintExtension | None, ): # guidance_vec is ignored for schnell. guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) @@ -40,15 +37,8 @@ def denoise( img = img + (t_prev - t_curr) * pred - if inpaint_mask is not None: - assert init_latents is not None - img = merge_intermediate_latents_with_init_latents( - init_latents=init_latents, - intermediate_latents=img, - timestep=t_prev, - noise=noise, - inpaint_mask=inpaint_mask, - ) + if inpaint_extension is not None: + img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev) step_callback() diff --git a/invokeai/backend/flux/inpaint.py b/invokeai/backend/flux/inpaint.py deleted file mode 100644 index 3bebb9c3e6..0000000000 --- a/invokeai/backend/flux/inpaint.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch - - -def merge_intermediate_latents_with_init_latents( - init_latents: torch.Tensor, - intermediate_latents: torch.Tensor, - timestep: float, - noise: torch.Tensor, - inpaint_mask: torch.Tensor, -) -> torch.Tensor: - # Noise the init_latents for the current timestep. - noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents - - # Merge the intermediate_latents with the noised_init_latents using the inpaint_mask. - return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask) diff --git a/invokeai/backend/flux/inpaint_extension.py b/invokeai/backend/flux/inpaint_extension.py new file mode 100644 index 0000000000..b6c634a3b5 --- /dev/null +++ b/invokeai/backend/flux/inpaint_extension.py @@ -0,0 +1,35 @@ +import torch + + +class InpaintExtension: + """A class for managing inpainting with FLUX.""" + + def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor): + """Initialize InpaintExtension. + + Args: + init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format. + inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be + re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the + inpainted region with the background. In 'packed' format. + noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format. + """ + assert init_latents.shape == inpaint_mask.shape == noise.shape + self._init_latents = init_latents + self._inpaint_mask = inpaint_mask + self._noise = noise + + def merge_intermediate_latents_with_init_latents( + self, intermediate_latents: torch.Tensor, timestep: float + ) -> torch.Tensor: + """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e. + update the intermediate latents to keep the regions that are not being inpainted on the correct noise + trajectory. + + This function should be called after each denoising step. + """ + # Noise the init latents for the current timestep. + noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents + + # Merge the intermediate latents with the noised_init_latents using the inpaint_mask. + return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask) diff --git a/invokeai/backend/flux/sampling_utils.py b/invokeai/backend/flux/sampling_utils.py index 9d710015af..4be15c491f 100644 --- a/invokeai/backend/flux/sampling_utils.py +++ b/invokeai/backend/flux/sampling_utils.py @@ -5,7 +5,6 @@ from typing import Callable import torch from einops import rearrange, repeat -from torch import Tensor def get_noise( @@ -31,7 +30,7 @@ def get_noise( ).to(device=device, dtype=dtype) -def time_shift(mu: float, sigma: float, t: Tensor): +def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor: return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) @@ -60,7 +59,8 @@ def get_schedule( return timesteps.tolist() -def unpack(x: Tensor, height: int, width: int) -> Tensor: +def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Unpack flat array of patch embeddings to latent image.""" return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", @@ -71,39 +71,27 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor: ) -def pack(x: Tensor) -> Tensor: +def pack(x: torch.Tensor) -> torch.Tensor: + """Pack latent image to flattented array of patch embeddings.""" # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) -def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor: +def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Generate tensor of image position ids. + + Args: + h (int): Height of image in latent space. + w (int): Width of image in latent space. + batch_size (int): Batch size. + device (torch.device): Device. + dtype (torch.dtype): dtype. + + Returns: + torch.Tensor: Image position ids. + """ img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) return img_ids - - -def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert an input image in latent space to patches for diffusion. - - This implementation was extracted from: - https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32 - - Args: - img (torch.Tensor): Input image in latent space. - - Returns: - tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. - """ - bs, c, h, w = img.shape - - img = pack(img) - - # Generate patch position ids. - img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - return img, img_ids