diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 2e80afc1e4..9504abee3e 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -131,10 +131,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
-            vae.to(torch.float32)
-            latents.to(torch.float32)
+            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
             img = vae.decode(latents)

         img = img.clamp(-1, 1)
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index bb57e4413c..e37b12c4f7 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -65,6 +65,7 @@ class FluxVAELoader(ModelLoader):
             model = AutoEncoder(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
+            model.to(dtype=self._torch_dtype)
             return model