fix(TAESD): correct usage of singledispatchmethod so normal VAE still works

This commit is contained in:
Kevin Turner 2023-08-18 14:05:12 -07:00
parent 26a7b7b66d
commit 4f0e43ec1b

View File

@ -736,11 +736,12 @@ class ImageToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
return latents
@singledispatchmethod
def _encode_to_tensor(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
@_encode_to_tensor.register
def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents