diff --git a/invokeai/app/invocations/flux_vae_decode.py b/invokeai/app/invocations/flux_vae_decode.py new file mode 100644 index 0000000000..54d9b57f13 --- /dev/null +++ b/invokeai/app/invocations/flux_vae_decode.py @@ -0,0 +1,60 @@ +import torch +from einops import rearrange +from PIL import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + LatentsField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "flux_vae_decode", + title="FLUX VAE Decode", + tags=["latents", "image", "vae", "l2i", "flux"], + category="latents", + version="1.0.0", +) +class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates an image from latents.""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + + def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: + with vae_info as vae: + assert isinstance(vae, AutoEncoder) + latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()) + img = vae.decode(latents) + + img = img.clamp(-1, 1) + img = rearrange(img[0], "c h w -> h w c") # noqa: F821 + img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) + return img_pil + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + vae_info = context.models.load(self.vae.vae) + image = self._vae_decode(vae_info=vae_info, latents=latents) + + TorchDevice.empty_cache() + image_dto = context.images.save(image=image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 7c7fdf07c7..6d359e82fd 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -10,8 +10,6 @@ from diffusers.models.attention_processor import ( ) from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from einops import rearrange -from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR @@ -26,9 +24,6 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.flux.modules.autoencoder import AutoEncoder -from invokeai.backend.model_manager.config import BaseModelType -from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice @@ -39,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.4.0", + version="1.3.0", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" @@ -58,7 +53,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) - def _vae_decode_stable_diffusion(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + + vae_info = context.models.load(self.vae.vae) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) @@ -114,38 +112,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = VaeImageProcessor.numpy_to_pil(np_image)[0] - return image - - def _vae_decode_flux(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image: - with vae_info as vae: - assert isinstance(vae, AutoEncoder) - latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()) - img = vae.decode(latents) - - img = img.clamp(-1, 1) - img = rearrange(img[0], "c h w -> h w c") # noqa: F821 - img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) - return img_pil - - @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.load(self.latents.latents_name) - - vae_info = context.models.load(self.vae.vae) - if vae_info.config.base == BaseModelType.Flux: - if self.fp32: - raise NotImplementedError("FLUX VAE decode does not currently support fp32=True.") - if self.tiled: - raise NotImplementedError("FLUX VAE decode does not currently support tiled=True.") - image = self._vae_decode_flux(vae_info=vae_info, latents=latents) - elif vae_info.config.base in [ - BaseModelType.StableDiffusion1, - BaseModelType.StableDiffusion2, - BaseModelType.StableDiffusionXL, - ]: - image = self._vae_decode_stable_diffusion(vae_info=vae_info, latents=latents) - else: - raise ValueError(f"Unsupported VAE base type: '{vae_info.config.base}'") TorchDevice.empty_cache()