Minor improvements to LatentsToImageInvocation type hints.

This commit is contained in:
Ryan Dick 2024-06-07 11:45:42 -04:00
parent da066979cf
commit bb5648983f

View File

@ -8,7 +8,6 @@ from diffusers.models.attention_processor import (
) )
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -59,9 +58,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
use_fp32: bool, use_fp32: bool,
use_tiling: bool, use_tiling: bool,
) -> Image.Image: ) -> Image.Image:
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae: with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device) latents = latents.to(vae.device)
if use_fp32: if use_fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)