use torch.bfloat16 on CUDA systems

Lincoln Stein 2024-01-04 20:42:29 -05:00 committed by Kent Keirsey
parent 59aa009c93
commit 6460dcc7e0


@@ -44,7 +44,7 @@ def torch_dtype(device: torch.device) -> torch.dtype:
     if config.full_precision:
         return torch.float32
     if choose_precision(device) == "float16":
-        return torch.float16
+        return torch.bfloat16 if device.type == "cuda" else torch.float16
     else:
         return torch.float32
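
For context, a minimal self-contained sketch of the dtype selection after this change. The `config` object and `choose_precision` helper live elsewhere in the patched file and are stubbed here as assumptions so the example runs on its own; the comment on the changed line reflects the usual rationale for this swap (bfloat16 keeps float32's 8-bit exponent range, so it is less prone to overflow than float16 on CUDA hardware that supports it).

import torch

# Sketch only, not the full module: `config` and `choose_precision`
# are hypothetical stand-ins for the real objects defined elsewhere
# in the file this diff patches.

class _StubConfig:
    full_precision = False

config = _StubConfig()

def choose_precision(device: torch.device) -> str:
    # Stub: the real helper inspects device capabilities.
    return "float16" if device.type in ("cuda", "mps") else "float32"

def torch_dtype(device: torch.device) -> torch.dtype:
    if config.full_precision:
        return torch.float32
    if choose_precision(device) == "float16":
        # After this commit: prefer bfloat16 on CUDA, since it shares
        # float32's exponent range and avoids float16 overflow/NaNs.
        return torch.bfloat16 if device.type == "cuda" else torch.float16
    else:
        return torch.float32

# Usage: CPU resolves to float32, CUDA now resolves to bfloat16.
print(torch_dtype(torch.device("cpu")))   # torch.float32
print(torch_dtype(torch.device("cuda")))  # torch.bfloat16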