From 79640ba14e9d0d93b39dbbcd1312e2ebe7767eb6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 26 Jun 2024 11:38:02 -0400 Subject: [PATCH 1/3] Add context manager for overriding VAE tiling params. --- invokeai/app/invocations/image_to_latents.py | 13 +++++++-- invokeai/app/invocations/latents_to_image.py | 17 ++++++++--- .../backend/stable_diffusion/vae_tiling.py | 29 +++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/vae_tiling.py diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 06de530154..73e6106ff9 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from functools import singledispatchmethod import einops @@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params @invocation( @@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation): @staticmethod def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: - assert isinstance(vae, torch.nn.Module) + assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation): vae.to(dtype=torch.float16) # latents = latents.half() + tiling_context = nullcontext() if tiled: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=512, + tile_latent_min_size=512 // 8, + tile_overlap_factor=0.25, + ) vae.enable_tiling() else: vae.disable_tiling() # non_noised_latents_from_image image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) - with torch.inference_mode(): + with torch.inference_mode(), tiling_context: latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor) latents = vae.config.scaling_factor * latents diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 202e8bfa1b..3d714730dc 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import torch from diffusers.image_processor import VaeImageProcessor from diffusers.models.attention_processor import ( @@ -8,7 +10,6 @@ from diffusers.models.attention_processor import ( ) from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION @@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion import set_seamless +from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice @@ -53,9 +55,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): latents = context.tensors.load(self.latents.latents_name) vae_info = context.models.load(self.vae.vae) - assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) + assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.nn.Module) + assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -82,7 +84,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.to(dtype=torch.float16) latents = latents.half() + tiling_context = nullcontext() if self.tiled or context.config.get().force_tiled_decode: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=512, + tile_latent_min_size=512 // 8, + tile_overlap_factor=0.25, + ) vae.enable_tiling() else: vae.disable_tiling() @@ -90,7 +99,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # clear memory as vae decode can request a lot TorchDevice.empty_cache() - with torch.inference_mode(): + with torch.inference_mode(), tiling_context: # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] diff --git a/invokeai/backend/stable_diffusion/vae_tiling.py b/invokeai/backend/stable_diffusion/vae_tiling.py new file mode 100644 index 0000000000..1fa7a18708 --- /dev/null +++ b/invokeai/backend/stable_diffusion/vae_tiling.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny + + +@contextmanager +def patch_vae_tiling_params( + vae: AutoencoderKL | AutoencoderTiny, + tile_sample_min_size: int, + tile_latent_min_size: int, + tile_overlap_factor: float, +): + # Record initial config. + orig_tile_sample_min_size = vae.tile_sample_min_size + orig_tile_latent_min_size = vae.tile_latent_min_size + orig_tile_overlap_factor = vae.tile_overlap_factor + + try: + # Apply target config. + vae.tile_sample_min_size = tile_sample_min_size + vae.tile_latent_min_size = tile_latent_min_size + vae.tile_overlap_factor = tile_overlap_factor + yield + finally: + # Restore initial config. + vae.tile_sample_min_size = orig_tile_sample_min_size + vae.tile_latent_min_size = orig_tile_latent_min_size + vae.tile_overlap_factor = orig_tile_overlap_factor From a1b7dbfa5442b9823a8d69a47f27be827880e0d5 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 28 Jun 2024 10:30:32 -0400 Subject: [PATCH 2/3] Add unit test for patch_vae_tiling_params(). --- tests/backend/stable_diffusion/test_vae_tiling.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/backend/stable_diffusion/test_vae_tiling.py diff --git a/tests/backend/stable_diffusion/test_vae_tiling.py b/tests/backend/stable_diffusion/test_vae_tiling.py new file mode 100644 index 0000000000..4d97a8869b --- /dev/null +++ b/tests/backend/stable_diffusion/test_vae_tiling.py @@ -0,0 +1,13 @@ +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + +from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params + + +def test_patch_vae_tiling_params(): + """Smoke test the patch_vae_tiling_params(...) context manager. The main purpose of this unit test is to detect if + diffusers ever changes the attributes of the AutoencoderKL class that we expect to exist. + """ + vae = AutoencoderKL() + + with patch_vae_tiling_params(vae, 1, 2, 3): + pass From 375250906661eda86fabfdeda50f0875f8fbc3d8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 28 Jun 2024 11:30:09 -0400 Subject: [PATCH 3/3] Expose the VAE tile_size on the VAE encode and decode invocations. --- invokeai/app/invocations/fields.py | 2 ++ invokeai/app/invocations/image_to_latents.py | 31 ++++++++++++------- invokeai/app/invocations/latents_to_image.py | 23 ++++++++------ .../backend/stable_diffusion/vae_tiling.py | 6 ++++ 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0fa0216f1c..b792453b47 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -160,6 +160,8 @@ class FieldDescriptions: fp32 = "Whether or not to use full float32 precision" precision = "Precision to use" tiled = "Processing using overlapping tiles (reduce memory consumption)" + vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the " + "model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage." detect_res = "Pixel resolution for detection" image_res = "Pixel resolution for output image" safe_mode = "Whether or not to use safe mode" diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 73e6106ff9..dadd8bb3a1 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.constants import DEFAULT_PRECISION +from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( FieldDescriptions, ImageField, @@ -33,7 +33,7 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.2", + version="1.1.0", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -46,10 +46,15 @@ class ImageToLatentsInvocation(BaseInvocation): input=Input.Connection, ) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + # NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not + # offer a way to directly set None values. + tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) @staticmethod - def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: + def vae_encode( + vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 + ) -> torch.Tensor: with vae_info as vae: assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) orig_dtype = vae.dtype @@ -78,18 +83,20 @@ class ImageToLatentsInvocation(BaseInvocation): vae.to(dtype=torch.float16) # latents = latents.half() - tiling_context = nullcontext() if tiled: - tiling_context = patch_vae_tiling_params( - vae, - tile_sample_min_size=512, - tile_latent_min_size=512 // 8, - tile_overlap_factor=0.25, - ) vae.enable_tiling() else: vae.disable_tiling() + tiling_context = nullcontext() + if tile_size > 0: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=tile_size, + tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR, + tile_overlap_factor=0.25, + ) + # non_noised_latents_from_image image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) with torch.inference_mode(), tiling_context: @@ -110,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation): if image_tensor.dim() == 3: image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") - latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) + latents = self.vae_encode( + vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size + ) latents = latents.to("cpu") name = context.tensors.save(tensor=latents) diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 3d714730dc..cc8a9c44a3 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.constants import DEFAULT_PRECISION +from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( FieldDescriptions, Input, @@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.2", + version="1.3.0", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" @@ -48,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): input=Input.Connection, ) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + # NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not + # offer a way to directly set None values. + tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size) fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32) @torch.no_grad() @@ -84,18 +87,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.to(dtype=torch.float16) latents = latents.half() - tiling_context = nullcontext() if self.tiled or context.config.get().force_tiled_decode: - tiling_context = patch_vae_tiling_params( - vae, - tile_sample_min_size=512, - tile_latent_min_size=512 // 8, - tile_overlap_factor=0.25, - ) vae.enable_tiling() else: vae.disable_tiling() + tiling_context = nullcontext() + if self.tile_size > 0: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=self.tile_size, + tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR, + tile_overlap_factor=0.25, + ) + # clear memory as vae decode can request a lot TorchDevice.empty_cache() diff --git a/invokeai/backend/stable_diffusion/vae_tiling.py b/invokeai/backend/stable_diffusion/vae_tiling.py index 1fa7a18708..d31cb331f4 100644 --- a/invokeai/backend/stable_diffusion/vae_tiling.py +++ b/invokeai/backend/stable_diffusion/vae_tiling.py @@ -11,6 +11,12 @@ def patch_vae_tiling_params( tile_latent_min_size: int, tile_overlap_factor: float, ): + """Patch the parameters that control the VAE tiling tile size and overlap. + + These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of + the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory + usage. + """ # Record initial config. orig_tile_sample_min_size = vae.tile_sample_min_size orig_tile_latent_min_size = vae.tile_latent_min_size