mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Multiple refinements on loaders:
- Cache stat collection enabled. - Implemented ONNX loading. - Add ability to specify the repo version variant in installer CLI. - If caller asks for a repo version that doesn't exist, will fall back to empty version rather than raising an error.
This commit is contained in:
committed by
psychedelicious
parent
0d3addc69b
commit
5745ce9c7d
@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device:
|
||||
return torch.device(config.device)
|
||||
|
||||
|
||||
def choose_precision(device: torch.device) -> str:
|
||||
"""Returns an appropriate precision for the given torch device"""
|
||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||
def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str:
|
||||
"""Return an appropriate precision for the given torch device."""
|
||||
app_config = app_config or config
|
||||
if device.type == "cuda":
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
if config.precision == "bfloat16":
|
||||
if app_config.precision == "float32":
|
||||
return "float32"
|
||||
elif app_config.precision == "bfloat16":
|
||||
return "bfloat16"
|
||||
else:
|
||||
return "float16"
|
||||
@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str:
|
||||
return "float32"
|
||||
|
||||
|
||||
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||
def torch_dtype(
|
||||
device: Optional[torch.device] = None,
|
||||
app_config: Optional[InvokeAIAppConfig] = None,
|
||||
) -> torch.dtype:
|
||||
device = device or choose_torch_device()
|
||||
precision = choose_precision(device)
|
||||
precision = choose_precision(device, app_config)
|
||||
if precision == "float16":
|
||||
return torch.float16
|
||||
if precision == "bfloat16":
|
||||
|
Reference in New Issue
Block a user