Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16.

This commit is contained in:
Ryan Dick
2024-08-22 18:16:43 +00:00
committed by Brandon
parent 14ab339b33
commit a0bf20bcee
2 changed files with 2 additions and 4 deletions

View File

@ -65,6 +65,7 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(params)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
return model