Multiple refinements on loaders:

- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
This commit is contained in:
Lincoln Stein
2024-02-05 21:55:11 -05:00
committed by psychedelicious
parent 0d3addc69b
commit 5745ce9c7d
18 changed files with 215 additions and 49 deletions

View File

@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device:
return torch.device(config.device)
def choose_precision(device: torch.device) -> str:
"""Returns an appropriate precision for the given torch device"""
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str:
"""Return an appropriate precision for the given torch device."""
app_config = app_config or config
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
if config.precision == "bfloat16":
if app_config.precision == "float32":
return "float32"
elif app_config.precision == "bfloat16":
return "bfloat16"
else:
return "float16"
@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str:
return "float32"
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def torch_dtype(
device: Optional[torch.device] = None,
app_config: Optional[InvokeAIAppConfig] = None,
) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
precision = choose_precision(device, app_config)
if precision == "float16":
return torch.float16
if precision == "bfloat16":