Add context manager for overriding VAE tiling params.

This commit is contained in:
Ryan Dick 2024-06-26 11:38:02 -04:00
parent 4075a81676
commit 79640ba14e
3 changed files with 53 additions and 6 deletions

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import singledispatchmethod from functools import singledispatchmethod
import einops import einops
@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation( @invocation(
@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@staticmethod @staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype orig_dtype = vae.dtype
if upcast: if upcast:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.to(dtype=torch.float16) vae.to(dtype=torch.float16)
# latents = latents.half() # latents = latents.half()
tiling_context = nullcontext()
if tiled: if tiled:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=512,
tile_latent_min_size=512 // 8,
tile_overlap_factor=0.25,
)
vae.enable_tiling() vae.enable_tiling()
else: else:
vae.disable_tiling() vae.disable_tiling()
# non_noised_latents_from_image # non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(): with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor) latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents latents = vae.config.scaling_factor * latents

View File

@ -1,3 +1,5 @@
from contextlib import nullcontext
import torch import torch
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
@ -8,7 +10,6 @@ from diffusers.models.attention_processor import (
) )
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION from invokeai.app.invocations.constants import DEFAULT_PRECISION
@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -53,9 +55,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents = context.tensors.load(self.latents.latents_name) latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
@ -82,7 +84,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16) vae.to(dtype=torch.float16)
latents = latents.half() latents = latents.half()
tiling_context = nullcontext()
if self.tiled or context.config.get().force_tiled_decode: if self.tiled or context.config.get().force_tiled_decode:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=512,
tile_latent_min_size=512 // 8,
tile_overlap_factor=0.25,
)
vae.enable_tiling() vae.enable_tiling()
else: else:
vae.disable_tiling() vae.disable_tiling()
@ -90,7 +99,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# clear memory as vae decode can request a lot # clear memory as vae decode can request a lot
TorchDevice.empty_cache() TorchDevice.empty_cache()
with torch.inference_mode(): with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline # copied from diffusers pipeline
latents = latents / vae.config.scaling_factor latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0] image = vae.decode(latents, return_dict=False)[0]

View File

@ -0,0 +1,29 @@
from contextlib import contextmanager
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@contextmanager
def patch_vae_tiling_params(
vae: AutoencoderKL | AutoencoderTiny,
tile_sample_min_size: int,
tile_latent_min_size: int,
tile_overlap_factor: float,
):
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size
orig_tile_overlap_factor = vae.tile_overlap_factor
try:
# Apply target config.
vae.tile_sample_min_size = tile_sample_min_size
vae.tile_latent_min_size = tile_latent_min_size
vae.tile_overlap_factor = tile_overlap_factor
yield
finally:
# Restore initial config.
vae.tile_sample_min_size = orig_tile_sample_min_size
vae.tile_latent_min_size = orig_tile_latent_min_size
vae.tile_overlap_factor = orig_tile_overlap_factor