mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move prepare_latent_image_patches(...) to sampling.py with all of the related FLUX inference code.
This commit is contained in:
parent
25c91efbb6
commit
14ab339b33
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
@ -16,7 +16,7 @@ from invokeai.app.invocations.primitives import ImageOutput
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.flux.model import Flux
|
from invokeai.backend.flux.model import Flux
|
||||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
|
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
img, img_ids = self._prepare_latent_img_patches(x)
|
img, img_ids = prepare_latent_img_patches(x)
|
||||||
|
|
||||||
is_schnell = "schnell" in transformer_info.config.config_path
|
is_schnell = "schnell" in transformer_info.config.config_path
|
||||||
|
|
||||||
@ -123,30 +123,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Convert an input image in latent space to patches for diffusion.
|
|
||||||
|
|
||||||
This implementation was extracted from:
|
|
||||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
|
||||||
"""
|
|
||||||
bs, c, h, w = latent_img.shape
|
|
||||||
|
|
||||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
|
||||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
if img.shape[0] == 1 and bs > 1:
|
|
||||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
|
||||||
|
|
||||||
# Generate patch position ids.
|
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
return img, img_ids
|
|
||||||
|
|
||||||
def _run_vae_decoding(
|
def _run_vae_decoding(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
|
@ -147,3 +147,28 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
|||||||
ph=2,
|
ph=2,
|
||||||
pw=2,
|
pw=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Convert an input image in latent space to patches for diffusion.
|
||||||
|
|
||||||
|
This implementation was extracted from:
|
||||||
|
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||||
|
"""
|
||||||
|
bs, c, h, w = latent_img.shape
|
||||||
|
|
||||||
|
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||||
|
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
if img.shape[0] == 1 and bs > 1:
|
||||||
|
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||||
|
|
||||||
|
# Generate patch position ids.
|
||||||
|
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
return img, img_ids
|
||||||
|
Loading…
Reference in New Issue
Block a user