Move prepare_latent_img_patches(...) to sampling.py, alongside the rest of the related FLUX inference code.

This commit is contained in:
Ryan Dick 2024-08-22 17:18:43 +00:00 committed by Brandon
parent 25c91efbb6
commit 14ab339b33
2 changed files with 28 additions and 27 deletions

View File

@ -1,5 +1,5 @@
import torch import torch
from einops import rearrange, repeat from einops import rearrange
from PIL import Image from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@ -16,7 +16,7 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -87,7 +87,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed, seed=self.seed,
) )
img, img_ids = self._prepare_latent_img_patches(x) img, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path is_schnell = "schnell" in transformer_info.config.config_path
@ -123,30 +123,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return x return x
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Args:
        latent_img: A 4-D tensor that is unpacked as (batch, channels, height, width). Height and width must
            both be even, because the image is split into 2x2 patches.

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. `img` has shape
        (batch, (height // 2) * (width // 2), channels * 4) and `img_ids` has shape
        (batch, (height // 2) * (width // 2), 3).

    Raises:
        ValueError: If `latent_img` is not 4-D, or if its height or width is odd.
    """
    if latent_img.ndim != 4:
        raise ValueError(f"Expected a 4-D latent image (b, c, h, w), got shape {tuple(latent_img.shape)}.")

    bs, _, h, w = latent_img.shape
    if h % 2 != 0 or w % 2 != 0:
        raise ValueError(f"Latent height and width must be even to form 2x2 patches, got h={h}, w={w}.")
    patch_h, patch_w = h // 2, w // 2

    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    # The batch axis is preserved by this rearrange, so no batch expansion is needed afterwards.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Generate patch position ids. Channel 0 stays zero; channels 1 and 2 hold the patch row and column index.
    img_ids = torch.zeros(patch_h, patch_w, 3, device=img.device)
    img_ids[..., 1] = torch.arange(patch_h, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(patch_w, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids
def _run_vae_decoding( def _run_vae_decoding(
self, self,
context: InvocationContext, context: InvocationContext,

View File

@ -147,3 +147,28 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
ph=2, ph=2,
pw=2, pw=2,
) )
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Args:
        latent_img: A 4-D tensor that is unpacked as (batch, channels, height, width). Height and width must
            both be even, because the image is split into 2x2 patches.

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. `img` has shape
        (batch, (height // 2) * (width // 2), channels * 4) and `img_ids` has shape
        (batch, (height // 2) * (width // 2), 3).

    Raises:
        ValueError: If `latent_img` is not 4-D, or if its height or width is odd.
    """
    if latent_img.ndim != 4:
        raise ValueError(f"Expected a 4-D latent image (b, c, h, w), got shape {tuple(latent_img.shape)}.")

    bs, _, h, w = latent_img.shape
    if h % 2 != 0 or w % 2 != 0:
        raise ValueError(f"Latent height and width must be even to form 2x2 patches, got h={h}, w={w}.")
    patch_h, patch_w = h // 2, w // 2

    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    # The batch axis is preserved by this rearrange, so no batch expansion is needed afterwards.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Generate patch position ids. Channel 0 stays zero; channels 1 and 2 hold the patch row and column index.
    img_ids = torch.zeros(patch_h, patch_w, 3, device=img.device)
    img_ids[..., 1] = torch.arange(patch_h, device=img.device)[:, None]
    img_ids[..., 2] = torch.arange(patch_w, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids