mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add FLUX VAE support to ImageToLatentsInvocation.
This commit is contained in:
parent
87261bdbc9
commit
21701173d8
@ -23,7 +23,9 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
|
||||
@ -52,7 +54,19 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def vae_encode_stable_diffusion(
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
@ -107,6 +121,27 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
if vae_info.config.base == BaseModelType.Flux:
|
||||
if upcast:
|
||||
raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.")
|
||||
if tiled:
|
||||
raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.")
|
||||
return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor)
|
||||
elif vae_info.config.base in [
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
]:
|
||||
return ImageToLatentsInvocation.vae_encode_stable_diffusion(
|
||||
vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
@ -258,16 +258,17 @@ class Decoder(nn.Module):
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
if sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user