Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation.

This commit is contained in:
Ryan Dick 2024-06-07 11:41:39 -04:00 committed by Kent Keirsey
parent 020e8eb413
commit 21d7ca45e6

View File

@ -9,6 +9,7 @@ from diffusers.models.attention_processor import (
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION from invokeai.app.invocations.constants import DEFAULT_PRECISION
@ -23,6 +24,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -48,16 +50,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad() @staticmethod
def invoke(self, context: InvocationContext) -> ImageOutput: def vae_decode(
latents = context.tensors.load(self.latents.latents_name) context: InvocationContext,
vae_info: LoadedModel,
vae_info = context.models.load(self.vae.vae) seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if use_fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@ -82,7 +88,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16) vae.to(dtype=torch.float16)
latents = latents.half() latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode: if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling() vae.enable_tiling()
else: else:
vae.disable_tiling() vae.disable_tiling()
@ -102,6 +108,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
TorchDevice.empty_cache() TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)