From 21701173d85507f9906e7a59e8c455a77096a7b7 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 28 Aug 2024 19:54:02 +0000 Subject: [PATCH] Add FLUX VAE support to ImageToLatentsInvocation. --- invokeai/app/invocations/image_to_latents.py | 37 +++++++++++++++++++- invokeai/backend/flux/modules/autoencoder.py | 28 +++++++++++---- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index dadd8bb3a1..e277173b70 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -23,7 +23,9 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.config import BaseModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params @@ -52,7 +54,19 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) @staticmethod - def vae_encode( + def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: + # TODO(ryand): Expose seed parameter at the invocation level. + # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes. + # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function + # should be used for VAE encode sampling. + generator = torch.Generator().manual_seed(0) + with vae_info as vae: + assert isinstance(vae, AutoEncoder) + latents = vae.encode(image_tensor, sample=True, generator=generator) + return latents + + @staticmethod + def vae_encode_stable_diffusion( vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 ) -> torch.Tensor: with vae_info as vae: @@ -107,6 +121,27 @@ class ImageToLatentsInvocation(BaseInvocation): return latents + @staticmethod + def vae_encode( + vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 + ) -> torch.Tensor: + if vae_info.config.base == BaseModelType.Flux: + if upcast: + raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.") + if tiled: + raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.") + return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor) + elif vae_info.config.base in [ + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + ]: + return ImageToLatentsInvocation.vae_encode_stable_diffusion( + vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size + ) + else: + raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'") + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py index 237769aba7..6b072a82f6 100644 --- a/invokeai/backend/flux/modules/autoencoder.py +++ b/invokeai/backend/flux/modules/autoencoder.py @@ -258,16 +258,17 @@ class Decoder(nn.Module): class DiagonalGaussian(nn.Module): - def __init__(self, sample: bool = True, chunk_dim: int = 1): + def __init__(self, chunk_dim: int = 1): super().__init__() - self.sample = sample self.chunk_dim = chunk_dim - def forward(self, z: Tensor) -> Tensor: + def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor: mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) - if self.sample: + if sample: std = torch.exp(0.5 * logvar) - return mean + std * torch.randn_like(mean) + # Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we + # have to use torch.randn(...) instead. + return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device) else: return mean @@ -297,8 +298,21 @@ class AutoEncoder(nn.Module): self.scale_factor = params.scale_factor self.shift_factor = params.shift_factor - def encode(self, x: Tensor) -> Tensor: - z = self.reg(self.encoder(x)) + def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor: + """Run VAE encoding on input tensor x. + + Args: + x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width). + sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean. + Defaults to True. + generator (torch.Generator | None, optional): Optional random number generator for reproducibility. + Defaults to None. + + Returns: + Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width). + """ + + z = self.reg(self.encoder(x), sample=sample, generator=generator) z = self.scale_factor * (z - self.shift_factor) return z