mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16.
This commit is contained in:
parent
14ab339b33
commit
a0bf20bcee
@ -131,10 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
assert isinstance(vae, AutoEncoder)
|
assert isinstance(vae, AutoEncoder)
|
||||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
|
||||||
# with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
|
|
||||||
vae.to(torch.float32)
|
|
||||||
latents.to(torch.float32)
|
|
||||||
img = vae.decode(latents)
|
img = vae.decode(latents)
|
||||||
|
|
||||||
img = img.clamp(-1, 1)
|
img = img.clamp(-1, 1)
|
||||||
|
@ -65,6 +65,7 @@ class FluxVAELoader(ModelLoader):
|
|||||||
model = AutoEncoder(params)
|
model = AutoEncoder(params)
|
||||||
sd = load_file(model_path)
|
sd = load_file(model_path)
|
||||||
model.load_state_dict(sd, assign=True)
|
model.load_state_dict(sd, assign=True)
|
||||||
|
model.to(dtype=self._torch_dtype)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user