Add FLUX VAE support to ImageToLatentsInvocation.

Repository: https://github.com/invoke-ai/InvokeAI (mirror)
Commit: 21701173d8
Parent: 87261bdbc9
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
@@ -23,7 +23,9 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@@ -52,7 +54,19 @@ class ImageToLatentsInvocation(BaseInvocation):
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @staticmethod
-    def vae_encode(
+    def vae_encode_flux(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        # TODO(ryand): Expose seed parameter at the invocation level.
+        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
+        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
+        # should be used for VAE encode sampling.
+        generator = torch.Generator().manual_seed(0)
+        with vae_info as vae:
+            assert isinstance(vae, AutoEncoder)
+            latents = vae.encode(image_tensor, sample=True, generator=generator)
+            return latents
+
+    @staticmethod
+    def vae_encode_stable_diffusion(
         vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
     ) -> torch.Tensor:
         with vae_info as vae:
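
The fixed-seed torch.Generator above, together with the TODO about device-consistent random tensors, touches a real PyTorch constraint: a CPU generator cannot be used to sample directly into a tensor on another device. Below is a minimal sketch, not part of the commit (the helper name seeded_noise_like is assumed), of one common workaround: sample on CPU, then move the noise.

import torch

def seeded_noise_like(t: torch.Tensor, seed: int) -> torch.Tensor:
    # Hypothetical helper: sample with a CPU generator so the result is
    # reproducible across devices, then move the noise to the target
    # tensor's device and dtype.
    generator = torch.Generator().manual_seed(seed)
    noise = torch.randn(t.size(), generator=generator, dtype=torch.float32)
    return noise.to(device=t.device, dtype=t.dtype)

t = torch.empty(1, 16, 32, 32)
assert torch.equal(seeded_noise_like(t, seed=0), seeded_noise_like(t, seed=0))  # same seed, same noise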
@@ -107,6 +121,27 @@ class ImageToLatentsInvocation(BaseInvocation):
 
         return latents
 
+    @staticmethod
+    def vae_encode(
+        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
+    ) -> torch.Tensor:
+        if vae_info.config.base == BaseModelType.Flux:
+            if upcast:
+                raise NotImplementedError("FLUX VAE encode does not currently support upcast=True.")
+            if tiled:
+                raise NotImplementedError("FLUX VAE encode does not currently support tiled=True.")
+            return ImageToLatentsInvocation.vae_encode_flux(vae_info=vae_info, image_tensor=image_tensor)
+        elif vae_info.config.base in [
+            BaseModelType.StableDiffusion1,
+            BaseModelType.StableDiffusion2,
+            BaseModelType.StableDiffusionXL,
+        ]:
+            return ImageToLatentsInvocation.vae_encode_stable_diffusion(
+                vae_info=vae_info, upcast=upcast, tiled=tiled, image_tensor=image_tensor, tile_size=tile_size
+            )
+        else:
+            raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
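
The new vae_encode is a pure dispatcher on the model base type. Here is a plain-Python sketch of the same control flow; every name in it is a stand-in for illustration, not an InvokeAI API.

from enum import Enum

class Base(Enum):
    FLUX = "flux"
    SD1 = "sd-1"
    SDXL = "sdxl"

def vae_encode(base: Base, upcast: bool = False, tiled: bool = False) -> str:
    if base == Base.FLUX:
        # FLUX path: reject options the FLUX encoder does not support yet.
        if upcast or tiled:
            raise NotImplementedError("FLUX VAE encode does not currently support this option.")
        return "flux path"
    elif base in (Base.SD1, Base.SDXL):
        return "stable diffusion path"
    else:
        raise ValueError(f"Unsupported VAE base type: '{base}'")

assert vae_encode(Base.FLUX) == "flux path"

Keeping the dispatch in one static method means a new base type fails loudly in the else branch rather than silently taking the wrong encoder.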
diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py
@@ -258,16 +258,17 @@ class Decoder(nn.Module):
 
 
 class DiagonalGaussian(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
+    def __init__(self, chunk_dim: int = 1):
         super().__init__()
-        self.sample = sample
         self.chunk_dim = chunk_dim
 
-    def forward(self, z: Tensor) -> Tensor:
+    def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
         mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
-        if self.sample:
+        if sample:
             std = torch.exp(0.5 * logvar)
-            return mean + std * torch.randn_like(mean)
+            # Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
+            # have to use torch.randn(...) instead.
+            return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
         else:
             return mean
 
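
What the reworked DiagonalGaussian.forward computes, as a self-contained sketch: the encoder output is split into mean and log-variance halves along chunk_dim, and sampling applies the reparameterization mean + std * noise, now driven by the optional generator.

import torch

z = torch.randn(1, 32, 8, 8)             # stand-in encoder output: mean and logvar stacked on dim 1
mean, logvar = torch.chunk(z, 2, dim=1)  # each half has shape (1, 16, 8, 8)
generator = torch.Generator().manual_seed(0)
std = torch.exp(0.5 * logvar)
latent = mean + std * torch.randn(mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
# With sample=False, forward would instead return `mean` (deterministic encode).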
@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
         self.scale_factor = params.scale_factor
         self.shift_factor = params.shift_factor
 
-    def encode(self, x: Tensor) -> Tensor:
-        z = self.reg(self.encoder(x))
+    def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
+        """Run VAE encoding on input tensor x.
+
+        Args:
+            x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
+            sample (bool, optional): If True, sample from the encoded distribution; otherwise, return the
+                distribution mean. Defaults to True.
+            generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
+                Defaults to None.
+
+        Returns:
+            Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
+        """
+
+        z = self.reg(self.encoder(x), sample=sample, generator=generator)
         z = self.scale_factor * (z - self.shift_factor)
         return z
 
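
An end-to-end sketch of the new encode contract. ToyAutoEncoder is hypothetical, not the FLUX model; it only mimics the new signature so the sample/generator behavior can be exercised in isolation.

import torch
from torch import Tensor, nn

class ToyAutoEncoder(nn.Module):
    # Hypothetical stand-in: a 1x1 conv plays the role of the real encoder
    # and predicts mean + logvar stacked on the channel dimension.
    def __init__(self, z_channels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(3, 2 * z_channels, kernel_size=1)

    def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
        mean, logvar = torch.chunk(self.conv(x), 2, dim=1)
        if not sample:
            return mean  # deterministic: distribution mean
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn(mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)

ae = ToyAutoEncoder()
img = torch.rand(1, 3, 8, 8)
z1 = ae.encode(img, sample=True, generator=torch.Generator().manual_seed(0))
z2 = ae.encode(img, sample=True, generator=torch.Generator().manual_seed(0))
assert torch.equal(z1, z2)  # same seed now yields identical latents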