From 375250906661eda86fabfdeda50f0875f8fbc3d8 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 28 Jun 2024 11:30:09 -0400
Subject: [PATCH] Expose the VAE tile_size on the VAE encode and decode invocations.

---
 invokeai/app/invocations/fields.py           |  2 ++
 invokeai/app/invocations/image_to_latents.py | 31 ++++++++++++-------
 invokeai/app/invocations/latents_to_image.py | 23 ++++++++------
 .../backend/stable_diffusion/vae_tiling.py   |  6 ++++
 4 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 0fa0216f1c..b792453b47 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -160,6 +160,8 @@ class FieldDescriptions:
     fp32 = "Whether or not to use full float32 precision"
     precision = "Precision to use"
     tiled = "Processing using overlapping tiles (reduce memory consumption)"
+    vae_tile_size = ("The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
+        "model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage.")
     detect_res = "Pixel resolution for detection"
     image_res = "Pixel resolution for output image"
     safe_mode = "Whether or not to use safe mode"
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
index 73e6106ff9..dadd8bb3a1 100644
--- a/invokeai/app/invocations/image_to_latents.py
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -33,7 +33,7 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
     title="Image to Latents",
     tags=["latents", "image", "vae", "i2l"],
     category="latents",
-    version="1.0.2",
+    version="1.1.0",
 )
 class ImageToLatentsInvocation(BaseInvocation):
     """Encodes an image into latents."""
@@ -46,10 +46,15 @@ class ImageToLatentsInvocation(BaseInvocation):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    # NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
+    # offer a way to directly set None values.
+    tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
+    def vae_encode(
+        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
+    ) -> torch.Tensor:
         with vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
@@ -78,18 +83,20 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            tiling_context = nullcontext()
             if tiled:
-                tiling_context = patch_vae_tiling_params(
-                    vae,
-                    tile_sample_min_size=512,
-                    tile_latent_min_size=512 // 8,
-                    tile_overlap_factor=0.25,
-                )
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
 
+            tiling_context = nullcontext()
+            if tile_size > 0:
+                tiling_context = patch_vae_tiling_params(
+                    vae,
+                    tile_sample_min_size=tile_size,
+                    tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
+                    tile_overlap_factor=0.25,
+                )
+
             # non_noised_latents_from_image
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
             with torch.inference_mode(), tiling_context:
@@ -110,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+        latents = self.vae_encode(
+            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
+        )
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index 3d714730dc..cc8a9c44a3 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="Latents to Image",
     tags=["latents", "image", "vae", "l2i"],
     category="latents",
-    version="1.2.2",
+    version="1.3.0",
 )
 class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Generates an image from latents."""
@@ -48,6 +48,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    # NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
+    # offer a way to directly set None values.
+    tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @torch.no_grad()
@@ -84,18 +87,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
-            tiling_context = nullcontext()
             if self.tiled or context.config.get().force_tiled_decode:
-                tiling_context = patch_vae_tiling_params(
-                    vae,
-                    tile_sample_min_size=512,
-                    tile_latent_min_size=512 // 8,
-                    tile_overlap_factor=0.25,
-                )
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
 
+            tiling_context = nullcontext()
+            if self.tile_size > 0:
+                tiling_context = patch_vae_tiling_params(
+                    vae,
+                    tile_sample_min_size=self.tile_size,
+                    tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
+                    tile_overlap_factor=0.25,
+                )
+
             # clear memory as vae decode can request a lot
             TorchDevice.empty_cache()
diff --git a/invokeai/backend/stable_diffusion/vae_tiling.py b/invokeai/backend/stable_diffusion/vae_tiling.py
index 1fa7a18708..d31cb331f4 100644
--- a/invokeai/backend/stable_diffusion/vae_tiling.py
+++ b/invokeai/backend/stable_diffusion/vae_tiling.py
@@ -11,6 +11,12 @@ def patch_vae_tiling_params(
     tile_latent_min_size: int,
     tile_overlap_factor: float,
 ):
+    """Patch the parameters that control the VAE tiling tile size and overlap.
+
+    These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
+    the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
+    usage.
+    """
     # Record initial config.
     orig_tile_sample_min_size = vae.tile_sample_min_size
     orig_tile_latent_min_size = vae.tile_latent_min_size
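
Note: the final hunk is truncated before the body of patch_vae_tiling_params, so the patch only shows the original config being recorded. For context, the function is expected to follow the standard save/patch/restore context-manager pattern. The sketch below is illustrative rather than the verbatim implementation; it assumes the diffusers AutoencoderKL exposes tile_sample_min_size, tile_latent_min_size, and tile_overlap_factor attributes, consistent with the fields the hunk reads:

    from contextlib import contextmanager

    from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

    @contextmanager
    def patch_vae_tiling_params(
        vae: AutoencoderKL,
        tile_sample_min_size: int,
        tile_latent_min_size: int,
        tile_overlap_factor: float,
    ):
        # Record initial config.
        orig_tile_sample_min_size = vae.tile_sample_min_size
        orig_tile_latent_min_size = vae.tile_latent_min_size
        orig_tile_overlap_factor = vae.tile_overlap_factor
        try:
            # Apply the requested tiling params for the duration of the `with` block.
            vae.tile_sample_min_size = tile_sample_min_size
            vae.tile_latent_min_size = tile_latent_min_size
            vae.tile_overlap_factor = tile_overlap_factor
            yield
        finally:
            # Restore the initial config, even if encoding/decoding raised.
            vae.tile_sample_min_size = orig_tile_sample_min_size
            vae.tile_latent_min_size = orig_tile_latent_min_size
            vae.tile_overlap_factor = orig_tile_overlap_factor

This pattern is why the invocations can wrap the encode/decode call in tiling_context and fall back to nullcontext() when tile_size == 0: the model's default tiling parameters are only overridden while the context is active.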