From 74e644c4ba327a6336176be5c31ab056b8f10fef Mon Sep 17 00:00:00 2001 From: Millun Atluri Date: Fri, 12 Jan 2024 13:40:37 -0500 Subject: [PATCH] Allow bfloat16 to be configurable in invoke.yaml (#5469) * feat: allow bfloat16 to be configurable in invoke.yaml * fix: `torch_dtype()` util - Use `choose_precision` to get the precision string - Do not reference deprecated `config.full_precision` flag (why does this still exist?), if a user had this enabled it would override their actual precision setting and potentially cause a lot of confusion. --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- invokeai/app/services/config/config_default.py | 2 +- invokeai/backend/util/devices.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 180c4a4b3e..0a07e3e9d7 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -263,7 +263,7 @@ class InvokeAIAppConfig(InvokeAISettings): # DEVICE device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device) - precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device) + precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device) # GENERATION sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 5b5f616ca9..d6d3ad727f 100644 --- a/invokeai/backend/util/devices.py +++ 
b/invokeai/backend/util/devices.py @@ -34,18 +34,23 @@ def choose_precision(device: torch.device) -> str: if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - return "float16" + if config.precision == "bfloat16": + return "bfloat16" + else: + return "float16" elif device.type == "mps": return "float16" return "float32" def torch_dtype(device: torch.device) -> torch.dtype: - if config.full_precision: - return torch.float32 - if choose_precision(device) == "float16": - return torch.bfloat16 if device.type == "cuda" else torch.float16 + precision = choose_precision(device) + if precision == "float16": + return torch.float16 + if precision == "bfloat16": + return torch.bfloat16 else: + # "auto", "autocast", "float32" return torch.float32