Make the VAE tile size configurable for tiled VAE (#6555)

## Summary

- This PR exposes a `tile_size` field on `ImageToLatentsInvocation` and
`LatentsToImageInvocation`.
  - Setting `tile_size = 0` preserves the default behaviour.
- This feature is primarily intended to support upscaling workflows that
require VAE encoding/decoding high resolution images. In the future, we
may want to expose the tile size as a global application config, but
that's a separate conversation.
- As a general rule, larger tile sizes produce better results at the
cost of higher memory usage.

### Example:

Original (5472x5472)

![orig](https://github.com/invoke-ai/InvokeAI/assets/14897797/af0a975d-11ed-4f3c-9e53-84f3da6c997e)

VAE roundtrip with 512x512 tiles (note the discoloration)

![vae_roundtrip_512x512](https://github.com/invoke-ai/InvokeAI/assets/14897797/d589ae3e-fe93-410a-904c-f61f0fc0f1f2)

VAE roundtrip with 1024x1024 tiles (some discoloration still present,
but less severe than at 512x512)

![vae_roundtrip_1024x1024](https://github.com/invoke-ai/InvokeAI/assets/14897797/d0bb9752-3bfa-444f-88c9-39a3ca89c748)


## Related Issues / Discussions

Related: #6144 

## QA Instructions

- [x] Test image generation via the Linear tab
- [x] Test VAE roundtrip with tiling disabled
- [x] Test VAE roundtrip with tiling and tile_size = 0
- [x] Test VAE roundtrip with tiling and tile_size > 0

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-02 09:16:07 -04:00 committed by GitHub
commit e9936c27fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 94 additions and 12 deletions

View File

@ -160,6 +160,8 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision" fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use" precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)" tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
"model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection" detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image" image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode" safe_mode = "Whether or not to use safe mode"

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import singledispatchmethod from functools import singledispatchmethod
import einops import einops
@ -12,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation( @invocation(
@ -31,7 +33,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
title="Image to Latents", title="Image to Latents",
tags=["latents", "image", "vae", "i2l"], tags=["latents", "image", "vae", "i2l"],
category="latents", category="latents",
version="1.0.2", version="1.1.0",
) )
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents.""" """Encodes an image into latents."""
@ -44,12 +46,17 @@ class ImageToLatentsInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
) )
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod @staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype orig_dtype = vae.dtype
if upcast: if upcast:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
@ -81,9 +88,18 @@ class ImageToLatentsInvocation(BaseInvocation):
else: else:
vae.disable_tiling() vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image # non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(): with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor) latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents latents = vae.config.scaling_factor * latents
@ -101,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3: if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu") latents = latents.to("cpu")
name = context.tensors.save(tensor=latents) name = context.tensors.save(tensor=latents)

View File

@ -1,3 +1,5 @@
from contextlib import nullcontext
import torch import torch
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
@ -8,10 +10,9 @@ from diffusers.models.attention_processor import (
) )
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
Input, Input,
@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -32,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image", title="Latents to Image",
tags=["latents", "image", "vae", "l2i"], tags=["latents", "image", "vae", "l2i"],
category="latents", category="latents",
version="1.2.2", version="1.3.0",
) )
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents.""" """Generates an image from latents."""
@ -46,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection, input=Input.Connection,
) )
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad() @torch.no_grad()
@ -53,9 +58,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents = context.tensors.load(self.latents.latents_name) latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
@ -87,10 +92,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
else: else:
vae.disable_tiling() vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot # clear memory as vae decode can request a lot
TorchDevice.empty_cache() TorchDevice.empty_cache()
with torch.inference_mode(): with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline # copied from diffusers pipeline
latents = latents / vae.config.scaling_factor latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0] image = vae.decode(latents, return_dict=False)[0]

View File

@ -0,0 +1,35 @@
from contextlib import contextmanager
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@contextmanager
def patch_vae_tiling_params(
vae: AutoencoderKL | AutoencoderTiny,
tile_sample_min_size: int,
tile_latent_min_size: int,
tile_overlap_factor: float,
):
"""Patch the parameters that control the VAE tiling tile size and overlap.
These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
usage.
"""
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size
orig_tile_overlap_factor = vae.tile_overlap_factor
try:
# Apply target config.
vae.tile_sample_min_size = tile_sample_min_size
vae.tile_latent_min_size = tile_latent_min_size
vae.tile_overlap_factor = tile_overlap_factor
yield
finally:
# Restore initial config.
vae.tile_sample_min_size = orig_tile_sample_min_size
vae.tile_latent_min_size = orig_tile_latent_min_size
vae.tile_overlap_factor = orig_tile_overlap_factor

View File

@ -0,0 +1,13 @@
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
def test_patch_vae_tiling_params():
"""Smoke test the patch_vae_tiling_params(...) context manager. The main purpose of this unit test is to detect if
diffusers ever changes the attributes of the AutoencoderKL class that we expect to exist.
"""
vae = AutoencoderKL()
with patch_vae_tiling_params(vae, 1, 2, 3):
pass