From b74bc7734760e2b06acf0aff57ba5bdbde526f78 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 26 Jun 2024 17:18:57 -0400 Subject: [PATCH] Revert "Revert "Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation."" This reverts commit 7cafd78d6e4b5ce2ff9bfbfce355f79d914e1c34. --- invokeai/app/invocations/latents_to_image.py | 42 +++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 202e8bfa1b..8134a9bfeb 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -8,7 +8,7 @@ from diffusers.models.attention_processor import ( ) from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION @@ -23,6 +23,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion import set_seamless from invokeai.backend.util.devices import TorchDevice @@ -48,16 +49,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) - @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.load(self.latents.latents_name) - - vae_info = context.models.load(self.vae.vae) - assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) - with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.nn.Module) + @staticmethod + def vae_decode( + context: InvocationContext, + vae_info: LoadedModel, + seamless_axes: list[str], + latents: torch.Tensor, + use_fp32: bool, + use_tiling: bool, + ) -> Image.Image: + assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) + with set_seamless(vae_info.model, seamless_axes), vae_info as vae: + assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) latents = latents.to(vae.device) - if self.fp32: + if use_fp32: vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( @@ -82,7 +87,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.to(dtype=torch.float16) latents = latents.half() - if self.tiled or context.config.get().force_tiled_decode: + if use_tiling or context.config.get().force_tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -102,6 +107,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): TorchDevice.empty_cache() + return image + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + vae_info = context.models.load(self.vae.vae) + + image = self.vae_decode( + context=context, + vae_info=vae_info, + seamless_axes=self.vae.seamless_axes, + latents=latents, + use_fp32=self.fp32, + use_tiling=self.tiled, + ) image_dto = context.images.save(image=image) return ImageOutput.build(image_dto)