diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 35b8483f2c..90885d5c40 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -10,6 +10,8 @@ from diffusers.models.attention_processor import ( ) from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny +from einops import rearrange +from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR @@ -24,6 +26,9 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.model_manager.config import BaseModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice @@ -53,11 +58,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) - @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.load(self.latents.latents_name) - - vae_info = context.models.load(self.vae.vae) + def _vae_decode_stable_diffusion(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) @@ -113,6 +114,38 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = VaeImageProcessor.numpy_to_pil(np_image)[0] + return image + + def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: + with vae_info as vae: + assert isinstance(vae, AutoEncoder) + latents = latents.to(dtype=TorchDevice.choose_torch_dtype()) + img = vae.decode(latents) + + img = img.clamp(-1, 1) + img = rearrange(img[0], "c h w -> h w c") # noqa: F821 + img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) + return img_pil + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + + vae_info = context.models.load(self.vae.vae) + if vae_info.config.base == BaseModelType.Flux: + if self.fp32: + raise NotImplementedError("FLUX VAE decode does not currently support fp32=True.") + if self.tiled: + raise NotImplementedError("FLUX VAE decode does not currently support tiled=True.") + image = self._vae_decode_flux(vae_info=vae_info, latents=latents) + elif vae_info.config.base in [ + BaseModelType.StableDiffusion1, + BaseModelType.StableDiffusion2, + BaseModelType.StableDiffusionXL, + ]: + image = self._vae_decode_stable_diffusion(vae_info=vae_info, latents=latents) + else: + raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'") TorchDevice.empty_cache()