From 1bbf2f269d13d0c3777b819e417e34133a6647a0 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Fri, 28 Jul 2023 21:02:48 -0400
Subject: [PATCH] Update installer

---
 installer/lib/installer.py | 7 +++++--
 installer/lib/messages.py  | 6 +++++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/installer/lib/installer.py b/installer/lib/installer.py
index e1ca8c2e8f..d7662b3bd4 100644
--- a/installer/lib/installer.py
+++ b/installer/lib/installer.py
@@ -451,7 +451,7 @@ def get_torch_source() -> (Union[str, None], str):
     device = graphical_accelerator()
 
     url = None
-    optional_modules = None
+    optional_modules = "[onnx]"
     if OS == "Linux":
         if device == "rocm":
             url = "https://download.pytorch.org/whl/rocm5.4.2"
@@ -460,7 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
 
     if device == "cuda":
         url = "https://download.pytorch.org/whl/cu117"
-        optional_modules = "[xformers]"
+        optional_modules = "[xformers,onnx-cuda]"
+    if device == "cuda_and_dml":
+        url = "https://download.pytorch.org/whl/cu117"
+        optional_modules = "[xformers,onnx-directml]"
 
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
 
diff --git a/installer/lib/messages.py b/installer/lib/messages.py
index 3687b52d32..cc7c579216 100644
--- a/installer/lib/messages.py
+++ b/installer/lib/messages.py
@@ -167,6 +167,10 @@ def graphical_accelerator():
         "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
         "cuda",
     )
+    nvidia_with_dml = (
+        "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)",
+        "cuda_and_dml",
+    )
     amd = (
         "an [gold1 b]AMD[/] GPU (using ROCm™)",
         "rocm",
@@ -181,7 +185,7 @@
     )
 
     if OS == "Windows":
-        options = [nvidia, cpu]
+        options = [nvidia, nvidia_with_dml, cpu]
     if OS == "Linux":
         options = [nvidia, amd, cpu]
     elif OS == "Darwin":
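
The patch only changes which extras string and wheel index get_torch_source() returns; the pip invocation that consumes them lies outside this diff. As a rough, hypothetical sketch only (the install_torch_extras helper, the "InvokeAI" package name, and the use of --extra-index-url are assumptions, not part of this patch), the returned (url, optional_modules) pair might be consumed along these lines:

    # Hypothetical sketch, not code from this patch: combine the extras string
    # returned by get_torch_source() with an accelerator-specific wheel index.
    import subprocess
    import sys

    def install_torch_extras(url, optional_modules):
        # e.g. optional_modules == "[xformers,onnx-directml]" -> "InvokeAI[xformers,onnx-directml]"
        spec = f"InvokeAI{optional_modules or ''}"
        cmd = [sys.executable, "-m", "pip", "install", spec]
        if url is not None:
            # search the accelerator-specific PyTorch wheel index in addition to PyPI
            cmd += ["--extra-index-url", url]
        subprocess.run(cmd, check=True)

    # Example: the new Windows "cuda_and_dml" choice added in messages.py
    install_torch_extras("https://download.pytorch.org/whl/cu117", "[xformers,onnx-directml]")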