Add FLUX VAE support to ImageToLatentsInvocation.

Ryan Dick 2024-08-28 19:54:02 +00:00
parent 87261bdbc9
commit 21701173d8
2 changed files with 57 additions and 8 deletions

invokeai/app/invocations/image_to_latents.py

@@ -23,7 +23,9 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@@ -52,7 +54,19 @@ class ImageToLatentsInvocation(BaseInvocation):
    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)

    @staticmethod
    def vae_encode(
    def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Expose seed parameter at the invocation level.
        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
        # should be used for VAE encode sampling.
        generator = torch.Generator().manual_seed(0)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = vae.encode(image_tensor, sample=True, generator=generator)
        return latents

    @staticmethod
    def vae_encode_stable_diffusion(
        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
    ) -> torch.Tensor:
        with vae_info as vae:
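The hard-coded seed above makes vae_encode_flux deterministic until the TODO to expose a seed at the invocation level is addressed. A minimal standalone sketch of that behavior (hypothetical latent shapes, not the real FLUX encoder):

import torch

def sample_latents(seed: int) -> torch.Tensor:
    # Same pattern as vae_encode_flux: seed a CPU generator, then draw the sampling noise from it.
    mean = torch.zeros(1, 16, 32, 32)
    std = torch.ones(1, 16, 32, 32)
    generator = torch.Generator().manual_seed(seed)
    noise = torch.randn(mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
    return mean + std * noise

assert torch.equal(sample_latents(0), sample_latents(0))  # same seed -> identical latents
assert not torch.equal(sample_latents(0), sample_latents(1))  # different seeds -> different latents (with overwhelming probability)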
@@ -107,6 +121,27 @@ class ImageToLatentsInvocation(BaseInvocation):
        return latents

    @staticmethod
    def vae_encode(
        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
    ) -> torch.Tensor:
        if vae_info.config.base == BaseModelType.Flux:
            if upcast:
                raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.")
            if tiled:
                raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.")
            return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor)
        elif vae_info.config.base in [
            BaseModelType.StableDiffusion1,
            BaseModelType.StableDiffusion2,
            BaseModelType.StableDiffusionXL,
        ]:
            return ImageToLatentsInvocation.vae_encode_stable_diffusion(
                vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size
            )
        else:
            raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)
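The new vae_encode dispatches on vae_info.config.base, rejecting options the FLUX path does not support yet and falling through to the existing Stable Diffusion path otherwise. A self-contained sketch of that control flow, using stand-in types and placeholder bodies (the real method takes a LoadedModel and calls the helpers above):

import torch
from enum import Enum

class BaseModelType(str, Enum):  # stand-in for invokeai.backend.model_manager.config.BaseModelType
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    Flux = "flux"

def vae_encode(base: BaseModelType, image_tensor: torch.Tensor, upcast: bool = False, tiled: bool = False) -> torch.Tensor:
    if base == BaseModelType.Flux:
        if upcast or tiled:
            raise NotImplementedError("FLUX VAE encode does not currently support upcast or tiled.")
        return image_tensor  # placeholder for vae_encode_flux(vae_info, image_tensor)
    elif base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL):
        return image_tensor  # placeholder for vae_encode_stable_diffusion(...)
    else:
        raise ValueError(f"Unsupported VAE base type: '{base}'")

vae_encode(BaseModelType.Flux, torch.zeros(1, 3, 64, 64))  # ok
# vae_encode(BaseModelType.Flux, torch.zeros(1, 3, 64, 64), tiled=True)  # would raise NotImplementedError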

invokeai/backend/flux/modules/autoencoder.py

@@ -258,16 +258,17 @@ class Decoder(nn.Module):
class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
    def __init__(self, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
    def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
        if sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
            # Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
            # have to use torch.randn(...) instead.
            return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
        else:
            return mean
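With the sampling flag and generator moved from __init__ to forward, a single DiagonalGaussian instance can serve both deterministic (mean-only) and sampled encodes. A small usage sketch with assumed shapes (16 latent channels):

import torch
from invokeai.backend.flux.modules.autoencoder import DiagonalGaussian

reg = DiagonalGaussian(chunk_dim=1)
z = torch.randn(1, 2 * 16, 32, 32)  # encoder output carries mean and logvar stacked along dim 1

mean_only = reg(z, sample=False)  # deterministic: returns the distribution mean
generator = torch.Generator().manual_seed(0)
sampled = reg(z, sample=True, generator=generator)  # reproducible sample drawn from the generator

assert mean_only.shape == sampled.shape == (1, 16, 32, 32)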
@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
    def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
        """Run VAE encoding on input tensor x.

        Args:
            x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
            sample (bool, optional): If True, sample from the encoded distribution; otherwise, return the
                distribution mean. Defaults to True.
            generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
                Defaults to None.

        Returns:
            Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
        """
        z = self.reg(self.encoder(x), sample=sample, generator=generator)
        z = self.scale_factor * (z - self.shift_factor)
        return z
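For reference, the final two lines of encode() normalize the sampled latent with the model's scale and shift factors, and decode() applies the inverse before running the decoder. A tiny sketch of that relationship (0.3611 / 0.1159 are the values commonly used for the FLUX autoencoder, assumed here rather than read from a checkpoint):

import torch

scale_factor, shift_factor = 0.3611, 0.1159
z_raw = torch.randn(1, 16, 32, 32)

z_latent = scale_factor * (z_raw - shift_factor)  # as in AutoEncoder.encode
z_restored = z_latent / scale_factor + shift_factor  # the inverse applied before decoding

assert torch.allclose(z_raw, z_restored, atol=1e-5)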