Add FLUX VAE decoding support to LatentsToImageInvocation.

This commit is contained in:
Ryan Dick 2024-08-28 20:30:00 +00:00
parent 21701173d8
commit e3a7bf12c1

View File

@ -10,6 +10,8 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
@ -24,6 +26,9 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@ -53,11 +58,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
def _vae_decode_stable_diffusion(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
@ -113,6 +114,38 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
return image
def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
if vae_info.config.base == BaseModelType.Flux:
if self.fp32:
raise NotImplementedError("FLUX VAE decode does not currently support fp32=True.")
if self.tiled:
raise NotImplementedError("FLUX VAE decode does not currently support tiled=True.")
image = self._vae_decode_flux(vae_info=vae_info, latents=latents)
elif vae_info.config.base in [
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]:
image = self._vae_decode_stable_diffusion(vae_info=vae_info, latents=latents)
else:
raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'")
TorchDevice.empty_cache()