mirror of https://github.com/invoke-ai/InvokeAI
Add context manager for overriding VAE tiling params.
commit 79640ba14e (parent 4075a81676)

The commit adds a patch_vae_tiling_params context manager and threads it through the VAE encode and decode invocations, so tiled VAE runs use fixed tile parameters (512-px sample tiles, 64-px latent tiles, 0.25 overlap) and the model's original settings are restored afterwards.
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from functools import singledispatchmethod
 
 import einops
@@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 
 
 @invocation(
@@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     @staticmethod
     def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
         with vae_info as vae:
-            assert isinstance(vae, torch.nn.Module)
+            assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
             if upcast:
                 vae.to(dtype=torch.float32)
@@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
+            tiling_context = nullcontext()
             if tiled:
+                tiling_context = patch_vae_tiling_params(
+                    vae,
+                    tile_sample_min_size=512,
+                    tile_latent_min_size=512 // 8,
+                    tile_overlap_factor=0.25,
+                )
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
 
             # non_noised_latents_from_image
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
-            with torch.inference_mode():
+            with torch.inference_mode(), tiling_context:
                 latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
 
             latents = vae.config.scaling_factor * latents
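Note (not part of the commit): the encode path initializes tiling_context to contextlib.nullcontext() so the later `with torch.inference_mode(), tiling_context:` line works whether or not tiling is enabled. A minimal self-contained sketch of the same pattern, with hypothetical names:

from contextlib import contextmanager, nullcontext


@contextmanager
def override_attr(obj, name, value):
    # Temporarily override a single attribute, restoring it on exit.
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


class Config:
    tile_size = 256


cfg = Config()
tiled = True
ctx = override_attr(cfg, "tile_size", 512) if tiled else nullcontext()
with ctx:
    print(cfg.tile_size)  # 512 while the override is active
print(cfg.tile_size)  # 256 again; the original value was restored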
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import torch
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
@@ -8,7 +10,6 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION
@@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
 
 
@@ -53,9 +55,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         latents = context.tensors.load(self.latents.latents_name)
 
         vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.nn.Module)
+            assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
@@ -82,7 +84,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
+            tiling_context = nullcontext()
             if self.tiled or context.config.get().force_tiled_decode:
+                tiling_context = patch_vae_tiling_params(
+                    vae,
+                    tile_sample_min_size=512,
+                    tile_latent_min_size=512 // 8,
+                    tile_overlap_factor=0.25,
+                )
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -90,7 +99,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             # clear memory as vae decode can request a lot
             TorchDevice.empty_cache()
 
-            with torch.inference_mode():
+            with torch.inference_mode(), tiling_context:
                 # copied from diffusers pipeline
                 latents = latents / vae.config.scaling_factor
                 image = vae.decode(latents, return_dict=False)[0]
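Aside (not from the commit): stacking torch.inference_mode() with a conditionally chosen context manager, as both encode and decode paths above do, can also be written with contextlib.ExitStack when the set of contexts is only decided at runtime. A small sketch with hypothetical names:

from contextlib import ExitStack

import torch


def run_under_contexts(fn, optional_ctx=None):
    # Always enter inference_mode; enter the extra context only if given.
    with ExitStack() as stack:
        stack.enter_context(torch.inference_mode())
        if optional_ctx is not None:
            stack.enter_context(optional_ctx)
        return fn()


out = run_under_contexts(lambda: torch.zeros(2) + 1.0)
print(out)  # tensor([1., 1.])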
invokeai/backend/stable_diffusion/vae_tiling.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from contextlib import contextmanager
+
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+
+
+@contextmanager
+def patch_vae_tiling_params(
+    vae: AutoencoderKL | AutoencoderTiny,
+    tile_sample_min_size: int,
+    tile_latent_min_size: int,
+    tile_overlap_factor: float,
+):
+    # Record initial config.
+    orig_tile_sample_min_size = vae.tile_sample_min_size
+    orig_tile_latent_min_size = vae.tile_latent_min_size
+    orig_tile_overlap_factor = vae.tile_overlap_factor
+
+    try:
+        # Apply target config.
+        vae.tile_sample_min_size = tile_sample_min_size
+        vae.tile_latent_min_size = tile_latent_min_size
+        vae.tile_overlap_factor = tile_overlap_factor
+        yield
+    finally:
+        # Restore initial config.
+        vae.tile_sample_min_size = orig_tile_sample_min_size
+        vae.tile_latent_min_size = orig_tile_latent_min_size
+        vae.tile_overlap_factor = orig_tile_overlap_factor
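A hypothetical usage sketch (the VAE construction and input shape below are illustrative assumptions; in the commit, the callers are the two invocations shown above):

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params

vae = AutoencoderKL()  # randomly initialized here; real use loads pretrained weights
vae.enable_tiling()

image = torch.randn(1, 3, 1024, 1024)
with torch.inference_mode(), patch_vae_tiling_params(
    vae,
    tile_sample_min_size=512,
    tile_latent_min_size=512 // 8,
    tile_overlap_factor=0.25,
):
    latents = vae.encode(image).latent_dist.sample()
# On exit, the VAE's original tiling parameters are restored.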