mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Expose the VAE tile_size on the VAE encode and decode invocations.
This commit is contained in:
@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@ -33,7 +33,7 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.1.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@ -46,10 +46,15 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def vae_encode(
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
@ -78,18 +83,20 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
# latents = latents.half()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if tiled:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=512,
|
||||
tile_latent_min_size=512 // 8,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
if tile_size > 0:
|
||||
tiling_context = patch_vae_tiling_params(
|
||||
vae,
|
||||
tile_sample_min_size=tile_size,
|
||||
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
|
||||
tile_overlap_factor=0.25,
|
||||
)
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode(), tiling_context:
|
||||
@ -110,7 +117,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
Reference in New Issue
Block a user