Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16.

Commit a0bf20bcee (parent 14ab339b33)
Author: Ryan Dick, 2024-08-22 18:16:43 +00:00; committed by Brandon
2 changed files with 2 additions and 4 deletions


@@ -131,10 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
-            vae.to(torch.float32)
-            latents.to(torch.float32)
+            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
             img = vae.decode(latents)
             img = img.clamp(-1, 1)
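
For reference, a minimal, self-contained sketch of the decode path after this change. The dtype is passed as a plain parameter here so the snippet runs outside InvokeAI; in the invocation above it comes from TorchDevice.choose_torch_dtype(), and `vae` is assumed to expose a decode() method as in the diff.

```python
import torch


def decode_latents(vae, latents: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Cast the latents to the preferred inference dtype instead of float32.
    # In the invocation above, this dtype comes from TorchDevice.choose_torch_dtype().
    latents = latents.to(dtype=dtype)
    # Decode, then clamp to [-1, 1] as in the diff above.
    img = vae.decode(latents)
    return img.clamp(-1, 1)
```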


@@ -65,6 +65,7 @@ class FluxVAELoader(ModelLoader):
         model = AutoEncoder(params)
         sd = load_file(model_path)
         model.load_state_dict(sd, assign=True)
+        model.to(dtype=self._torch_dtype)
         return model
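
A similar sketch of the loader-side half of the change: after the safetensors state dict is loaded, the whole VAE is cast to the configured dtype so its weights match the dtype the latents are cast to at decode time. The model is passed in already constructed so the sketch does not depend on InvokeAI's AutoEncoder class, and the `torch_dtype` parameter stands in for `self._torch_dtype`.

```python
import torch
from safetensors.torch import load_file


def load_vae_weights(model: torch.nn.Module, model_path: str, torch_dtype: torch.dtype) -> torch.nn.Module:
    # Load the checkpoint and assign tensors directly into the module
    # (assign=True, available in PyTorch >= 2.1, avoids copying into
    # pre-allocated parameters).
    sd = load_file(model_path)
    model.load_state_dict(sd, assign=True)
    # Cast all weights to the configured dtype (self._torch_dtype in the
    # loader above) so decoding runs in that dtype end to end.
    model.to(dtype=torch_dtype)
    return model
```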