mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add context manager for overriding VAE tiling params.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from contextlib import nullcontext
|
||||
from functools import singledispatchmethod
|
||||
|
||||
import einops
|
||||
@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
# latents = latents.half()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if tiled:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=512,
|
||||
tile_latent_min_size=512 // 8,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), tiling_context:
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
|
Reference in New Issue
Block a user