Expose the VAE tile_size on the VAE encode and decode invocations.

This commit is contained in:
Ryan Dick 2024-06-28 11:30:09 -04:00
parent a1b7dbfa54
commit 3752509066
4 changed files with 42 additions and 20 deletions

View File

@ -160,6 +160,8 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
"model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"

View File

@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -33,7 +33,7 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@ -46,10 +46,15 @@ class ImageToLatentsInvocation(BaseInvocation):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
@ -78,18 +83,20 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.to(dtype=torch.float16)
# latents = latents.half()
tiling_context = nullcontext()
if tiled:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=512,
tile_latent_min_size=512 // 8,
tile_overlap_factor=0.25,
)
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
@ -110,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)

View File

@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@ -48,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
@ -84,18 +87,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
tiling_context = nullcontext()
if self.tiled or context.config.get().force_tiled_decode:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=512,
tile_latent_min_size=512 // 8,
tile_overlap_factor=0.25,
)
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()

View File

@ -11,6 +11,12 @@ def patch_vae_tiling_params(
tile_latent_min_size: int,
tile_overlap_factor: float,
):
"""Patch the parameters that control the VAE tiling tile size and overlap.
These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
usage.
"""
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size