Update installer

This commit is contained in:
Brandon Rising 2023-07-28 21:02:48 -04:00
parent d3f6c7f983
commit 1bbf2f269d
2 changed files with 10 additions and 3 deletions

View File

@ -451,7 +451,7 @@ def get_torch_source() -> (Union[str, None], str):
device = graphical_accelerator()
url = None
optional_modules = None
optional_modules = "[onnx]"
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2"
@ -460,7 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
if device == "cuda":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers]"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -167,6 +167,10 @@ def graphical_accelerator():
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"cuda",
)
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)",
"cuda_and_dml",
)
amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)",
"rocm",
@ -181,7 +185,7 @@ def graphical_accelerator():
)
if OS == "Windows":
options = [nvidia, cpu]
options = [nvidia, nvidia_with_dml, cpu]
if OS == "Linux":
options = [nvidia, amd, cpu]
elif OS == "Darwin":