Add context manager for overriding VAE tiling params.

This commit is contained in:
Ryan Dick
2024-06-26 11:38:02 -04:00
parent 4075a81676
commit 79640ba14e
3 changed files with 53 additions and 6 deletions

View File

@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.to(dtype=torch.float16)
# latents = latents.half()
tiling_context = nullcontext()
if tiled:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=512,
tile_latent_min_size=512 // 8,
tile_overlap_factor=0.25,
)
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents