From a0bf20bcee8f4e41fcb9aac5b069cd229f79745c Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 22 Aug 2024 18:16:43 +0000
Subject: [PATCH] Run FLUX VAE decoding in the user's preferred dtype rather
 than float32.

Tested, and seems to work well at float16.
---
 invokeai/app/invocations/flux_text_to_image.py            | 5 +----
 invokeai/backend/model_manager/load/model_loaders/flux.py | 1 +
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 2e80afc1e4..9504abee3e 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -131,10 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
-            vae.to(torch.float32)
-            latents.to(torch.float32)
+            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
             img = vae.decode(latents)
 
         img = img.clamp(-1, 1)

diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index bb57e4413c..e37b12c4f7 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -65,6 +65,7 @@ class FluxVAELoader(ModelLoader):
             model = AutoEncoder(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
+            model.to(dtype=self._torch_dtype)
 
         return model
 
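
Note (reviewer aid, not part of the patch): a minimal sketch of the decode
path after this change. `decode_latents` is a hypothetical standalone
reduction of the invocation code above, not a function in the repo; in the
real code the dtype comes from InvokeAI's `TorchDevice.choose_torch_dtype()`
helper, and the VAE weights are cast once at load time (second hunk), so
only the latents need casting at decode time.

    import torch

    def decode_latents(vae, latents: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Cast only the incoming latents; the VAE weights were already cast
        # to `dtype` when the model was loaded, so the two dtypes match.
        latents = latents.to(dtype=dtype)
        img = vae.decode(latents)
        # Clamp decoded images to the [-1, 1] range expected downstream.
        return img.clamp(-1, 1)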