mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow bfloat16 to be configurable in invoke.yaml (#5469)
* feat: allow bfloat16 to be configurable in invoke.yaml * fix: `torch_dtype()` util - Use `choose_precision` to get the precision string - Do not reference deprecated `config.full_precision` flag (why does this still exist?), if a user had this enabled it would override their actual precision setting and potentially cause a lot of confusion. --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
d4c36da3ee
commit
74e644c4ba
@ -263,7 +263,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
# DEVICE
|
# DEVICE
|
||||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||||
|
|
||||||
# GENERATION
|
# GENERATION
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||||
|
@ -34,6 +34,9 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||||
|
if config.precision == "bfloat16":
|
||||||
|
return "bfloat16"
|
||||||
|
else:
|
||||||
return "float16"
|
return "float16"
|
||||||
elif device.type == "mps":
|
elif device.type == "mps":
|
||||||
return "float16"
|
return "float16"
|
||||||
@ -41,11 +44,13 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
if config.full_precision:
|
precision = choose_precision(device)
|
||||||
return torch.float32
|
if precision == "float16":
|
||||||
if choose_precision(device) == "float16":
|
return torch.float16
|
||||||
return torch.bfloat16 if device.type == "cuda" else torch.float16
|
if precision == "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
else:
|
else:
|
||||||
|
# "auto", "autocast", "float32"
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user