mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use torch.bfloat16 on cuda systems
This commit is contained in:
parent
59aa009c93
commit
6460dcc7e0
@ -44,7 +44,7 @@ def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
if config.full_precision:
|
||||
return torch.float32
|
||||
if choose_precision(device) == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16 if device.type == "cuda" else torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user