Update installer

This commit is contained in:
Brandon Rising 2023-07-28 21:02:48 -04:00
parent d3f6c7f983
commit 1bbf2f269d
2 changed files with 10 additions and 3 deletions

View File

@ -451,7 +451,7 @@ def get_torch_source() -> (Union[str, None], str):
device = graphical_accelerator() device = graphical_accelerator()
url = None url = None
optional_modules = None optional_modules = "[onnx]"
if OS == "Linux": if OS == "Linux":
if device == "rocm": if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2" url = "https://download.pytorch.org/whl/rocm5.4.2"
@ -460,7 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
if device == "cuda": if device == "cuda":
url = "https://download.pytorch.org/whl/cu117" url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers]" optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -167,6 +167,10 @@ def graphical_accelerator():
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)", "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"cuda", "cuda",
) )
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)",
"cuda_and_dml",
)
amd = ( amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)", "an [gold1 b]AMD[/] GPU (using ROCm™)",
"rocm", "rocm",
@ -181,7 +185,7 @@ def graphical_accelerator():
) )
if OS == "Windows": if OS == "Windows":
options = [nvidia, cpu] options = [nvidia, nvidia_with_dml, cpu]
if OS == "Linux": if OS == "Linux":
options = [nvidia, amd, cpu] options = [nvidia, amd, cpu]
elif OS == "Darwin": elif OS == "Darwin":