Allow bfloat16 to be configurable in invoke.yaml (#5469)

* feat: allow bfloat16 to be configurable in invoke.yaml

* fix: `torch_dtype()` util

- Use `choose_precision` to get the precision string
- Do not reference the deprecated `config.full_precision` flag (why does this still exist?); if a user had it enabled, it would override their actual precision setting and potentially cause a lot of confusion. A simplified sketch of the resulting dtype selection follows below.
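
To illustrate the behavior described above, here is a minimal standalone sketch, assuming a hypothetical `resolve_dtype()` helper that mirrors what `choose_precision()` plus `torch_dtype()` do after this change; it omits the GTX 1660/1650 special case and is not the actual InvokeAI code:

    import torch

    def resolve_dtype(configured_precision: str, device: torch.device) -> torch.dtype:
        # Hypothetical helper, not InvokeAI's API: maps the configured precision
        # string to a torch dtype the way the commit message describes.
        if device.type == "cuda" and configured_precision == "bfloat16":
            return torch.bfloat16   # honor `precision: bfloat16` from invoke.yaml
        if device.type in ("cuda", "mps"):
            return torch.float16    # default half precision on GPU/MPS
        return torch.float32        # "auto", "autocast", "float32", and CPU

    print(resolve_dtype("bfloat16", torch.device("cpu")))  # torch.float32 on CPU

With `precision: bfloat16` set in invoke.yaml, the real code path is expected to resolve to bfloat16 on supported CUDA devices.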

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Author: Millun Atluri
Date: 2024-01-12 13:40:37 -05:00
Committed by: GitHub
Parent: d4c36da3ee
Commit: 74e644c4ba
2 changed files with 11 additions and 6 deletions

@@ -34,18 +34,23 @@ def choose_precision(device: torch.device) -> str:
     if device.type == "cuda":
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
-            return "float16"
+            if config.precision == "bfloat16":
+                return "bfloat16"
+            else:
+                return "float16"
     elif device.type == "mps":
         return "float16"
     return "float32"


 def torch_dtype(device: torch.device) -> torch.dtype:
-    if config.full_precision:
-        return torch.float32
-    if choose_precision(device) == "float16":
-        return torch.bfloat16 if device.type == "cuda" else torch.float16
+    precision = choose_precision(device)
+    if precision == "float16":
+        return torch.float16
+    if precision == "bfloat16":
+        return torch.bfloat16
     else:
+        # "auto", "autocast", "float32"
         return torch.float32
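
A hedged usage sketch of how a caller might consume the updated `torch_dtype()`; the import path and the tiny model below are assumptions for illustration, not part of this commit:

    import torch
    from invokeai.backend.util.devices import torch_dtype  # assumed module path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch_dtype(device)  # float16, bfloat16, or float32, per config.precision
    model = torch.nn.Linear(4, 4).to(device=device, dtype=dtype)  # illustrative model
    x = torch.ones(1, 4, device=device, dtype=dtype)
    print(model(x).dtype)  # expected to match the configured precision

On CPU this resolves to float32; with `precision: bfloat16` configured and a supported CUDA device, the output dtype is expected to be bfloat16.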