From 08633c3f045ada92f855647372752abf3124eb48 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 22 Aug 2024 17:18:43 +0000 Subject: [PATCH] Move prepare_latent_image_patches(...) to sampling.py with all of the related FLUX inference code. --- .../app/invocations/flux_text_to_image.py | 30 ++----------------- invokeai/backend/flux/sampling.py | 25 ++++++++++++++++ 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index f29b3dd309..2e80afc1e4 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange, repeat +from einops import rearrange from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation @@ -16,7 +16,7 @@ from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.model import Flux from invokeai.backend.flux.modules.autoencoder import AutoEncoder -from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack +from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice @@ -87,7 +87,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): seed=self.seed, ) - img, img_ids = self._prepare_latent_img_patches(x) + img, img_ids = prepare_latent_img_patches(x) is_schnell = "schnell" in transformer_info.config.config_path @@ -123,30 +123,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): return x - def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert an input image in latent space to patches for diffusion. - - This implementation was extracted from: - https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32 - - Returns: - tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. - """ - bs, c, h, w = latent_img.shape - - # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. - img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - if img.shape[0] == 1 and bs > 1: - img = repeat(img, "1 ... -> bs ...", bs=bs) - - # Generate patch position ids. - img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - return img, img_ids - def _run_vae_decoding( self, context: InvocationContext, diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py index 82abc0e561..318a0bcdce 100644 --- a/invokeai/backend/flux/sampling.py +++ b/invokeai/backend/flux/sampling.py @@ -147,3 +147,28 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor: ph=2, pw=2, ) + + +def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert an input image in latent space to patches for diffusion. + + This implementation was extracted from: + https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32 + + Returns: + tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. + """ + bs, c, h, w = latent_img.shape + + # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. + img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + # Generate patch position ids. + img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + return img, img_ids