fix(installer) slightly better typing for GPU selection

This commit is contained in:
Eugene Brodsky 2024-02-05 11:51:01 -05:00 committed by Kent Keirsey
parent ca2bb6f0cc
commit 29bcc4b595
2 changed files with 34 additions and 29 deletions

View File

@ -368,25 +368,26 @@ def get_torch_source() -> Tuple[str | None, str | None]:
:rtype: list :rtype: list
""" """
from messages import graphical_accelerator from messages import select_gpu
# device can be one of: "cuda", "rocm", "cpu", "idk" # device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
device = graphical_accelerator() device = select_gpu()
url = None url = None
optional_modules = "[onnx]" optional_modules = "[onnx]"
if OS == "Linux": if OS == "Linux":
if device == "rocm": if device.value == "rocm":
url = "https://download.pytorch.org/whl/rocm5.6" url = "https://download.pytorch.org/whl/rocm5.6"
elif device == "cpu": elif device.value == "cpu":
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"
if device == "cuda": elif OS == "Windows":
url = "https://download.pytorch.org/whl/cu121" if device.value == "cuda":
optional_modules = "[xformers,onnx-cuda]" url = "https://download.pytorch.org/whl/cu121"
if device == "cuda_and_dml": optional_modules = "[xformers,onnx-cuda]"
url = "https://download.pytorch.org/whl/cu121" if device.value == "cuda_and_dml":
optional_modules = "[xformers,onnx-directml]" url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -5,6 +5,7 @@ Installer user interaction
import os import os
import platform import platform
from enum import Enum
from pathlib import Path from pathlib import Path
from prompt_toolkit import HTML, prompt from prompt_toolkit import HTML, prompt
@ -182,39 +183,42 @@ def dest_path(dest=None) -> Path | None:
console.rule("Goodbye!") console.rule("Goodbye!")
def graphical_accelerator(): class GpuType(Enum):
CUDA = "cuda"
CUDA_AND_DML = "cuda_and_dml"
ROCM = "rocm"
CPU = "cpu"
AUTODETECT = "autodetect"
def select_gpu() -> GpuType:
""" """
Prompt the user to select the graphical accelerator in their system Prompt the user to select the GPU driver
This does not validate user's choices (yet), but only offers choices
valid for the platform.
CUDA is the fallback.
We may be able to detect the GPU driver by shelling out to `modprobe` or `lspci`,
but this is not yet supported or reliable. Also, some users may have exotic preferences.
""" """
if ARCH == "arm64" and OS != "Darwin": if ARCH == "arm64" and OS != "Darwin":
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.") print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
return "cpu" return GpuType.CPU
nvidia = ( nvidia = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)", "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"cuda", GpuType.CUDA,
) )
nvidia_with_dml = ( nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA", "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
"cuda_and_dml", GpuType.CUDA_AND_DML,
) )
amd = ( amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)", "an [gold1 b]AMD[/] GPU (using ROCm™)",
"rocm", GpuType.ROCM,
) )
cpu = ( cpu = (
"no compatible GPU, or specifically prefer to use the CPU", "Do not install any GPU support, use CPU for generation (slow)",
"cpu", GpuType.CPU,
) )
idk = ( autodetect = (
"I'm not sure what to choose", "I'm not sure what to choose",
"idk", GpuType.AUTODETECT,
) )
options = [] options = []
@ -231,7 +235,7 @@ def graphical_accelerator():
return options[0][1] return options[0][1]
# "I don't know" is always added the last option # "I don't know" is always added the last option
options.append(idk) # type: ignore options.append(autodetect) # type: ignore
options = {str(i): opt for i, opt in enumerate(options, 1)} options = {str(i): opt for i, opt in enumerate(options, 1)}
@ -266,9 +270,9 @@ def graphical_accelerator():
), ),
) )
if options[choice][1] == "idk": if options[choice][1] is GpuType.AUTODETECT:
console.print( console.print(
"No problem. We will try to install a version that [i]should[/i] be compatible. :crossed_fingers:" "No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
) )
return options[choice][1] return options[choice][1]