From 29bcc4b59524fb6c7fa39317e2656e8b95fe3488 Mon Sep 17 00:00:00 2001
From: Eugene Brodsky
Date: Mon, 5 Feb 2024 11:51:01 -0500
Subject: [PATCH] fix(installer) slightly better typing for GPU selection

---
 installer/lib/installer.py | 23 +++++++++++-----------
 installer/lib/messages.py  | 40 +++++++++++++++++++++-----------------
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/installer/lib/installer.py b/installer/lib/installer.py
index c28e94d720..dcd542a805 100644
--- a/installer/lib/installer.py
+++ b/installer/lib/installer.py
@@ -368,25 +368,26 @@ def get_torch_source() -> Tuple[str | None, str | None]:
     :rtype: list
     """
 
-    from messages import graphical_accelerator
+    from messages import select_gpu
 
-    # device can be one of: "cuda", "rocm", "cpu", "idk"
-    device = graphical_accelerator()
+    # device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
+    device = select_gpu()
 
     url = None
     optional_modules = "[onnx]"
     if OS == "Linux":
-        if device == "rocm":
+        if device.value == "rocm":
             url = "https://download.pytorch.org/whl/rocm5.6"
-        elif device == "cpu":
+        elif device.value == "cpu":
             url = "https://download.pytorch.org/whl/cpu"
 
-    if device == "cuda":
-        url = "https://download.pytorch.org/whl/cu121"
-        optional_modules = "[xformers,onnx-cuda]"
-    if device == "cuda_and_dml":
-        url = "https://download.pytorch.org/whl/cu121"
-        optional_modules = "[xformers,onnx-directml]"
+    elif OS == "Windows":
+        if device.value == "cuda":
+            url = "https://download.pytorch.org/whl/cu121"
+            optional_modules = "[xformers,onnx-cuda]"
+        if device.value == "cuda_and_dml":
+            url = "https://download.pytorch.org/whl/cu121"
+            optional_modules = "[xformers,onnx-directml]"
 
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
 
diff --git a/installer/lib/messages.py b/installer/lib/messages.py
index c2015e6678..954478ba6c 100644
--- a/installer/lib/messages.py
+++ b/installer/lib/messages.py
@@ -5,6 +5,7 @@ Installer user interaction
 
 import os
 import platform
+from enum import Enum
 from pathlib import Path
 
 from prompt_toolkit import HTML, prompt
@@ -182,39 +183,42 @@ def dest_path(dest=None) -> Path | None:
         console.rule("Goodbye!")
 
 
-def graphical_accelerator():
+class GpuType(Enum):
+    CUDA = "cuda"
+    CUDA_AND_DML = "cuda_and_dml"
+    ROCM = "rocm"
+    CPU = "cpu"
+    AUTODETECT = "autodetect"
+
+
+def select_gpu() -> GpuType:
     """
-    Prompt the user to select the graphical accelerator in their system
-    This does not validate user's choices (yet), but only offers choices
-    valid for the platform.
-    CUDA is the fallback.
-    We may be able to detect the GPU driver by shelling out to `modprobe` or `lspci`,
-    but this is not yet supported or reliable. Also, some users may have exotic preferences.
+    Prompt the user to select the GPU driver
     """
 
     if ARCH == "arm64" and OS != "Darwin":
         print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
-        return "cpu"
+        return GpuType.CPU
 
     nvidia = (
         "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
-        "cuda",
+        GpuType.CUDA,
     )
     nvidia_with_dml = (
         "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
-        "cuda_and_dml",
+        GpuType.CUDA_AND_DML,
    )
     amd = (
         "an [gold1 b]AMD[/] GPU (using ROCm™)",
-        "rocm",
+        GpuType.ROCM,
     )
     cpu = (
-        "no compatible GPU, or specifically prefer to use the CPU",
-        "cpu",
+        "Do not install any GPU support, use CPU for generation (slow)",
+        GpuType.CPU,
     )
-    idk = (
+    autodetect = (
         "I'm not sure what to choose",
-        "idk",
+        GpuType.AUTODETECT,
     )
 
     options = []
@@ -231,7 +235,7 @@
         return options[0][1]
 
     # "I don't know" is always added the last option
-    options.append(idk)  # type: ignore
+    options.append(autodetect)  # type: ignore
 
     options = {str(i): opt for i, opt in enumerate(options, 1)}
 
@@ -266,9 +270,9 @@
         ),
     )
 
-    if options[choice][1] == "idk":
+    if options[choice][1] is GpuType.AUTODETECT:
         console.print(
-            "No problem. We will try to install a version that [i]should[/i] be compatible. :crossed_fingers:"
+            "No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
         )
 
     return options[choice][1]
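
Below is a small, self-contained sketch (not part of the patch) of the pattern this change introduces: a GpuType enum in place of bare device strings, plus the GPU-to-extra-index-URL mapping performed by get_torch_source(). The helper name torch_source_for and the simplified OS handling are illustrative assumptions only; note the sketch compares enum members by identity, whereas the patch itself matches on device.value.

# Illustrative sketch only -- not part of the patch. torch_source_for is a
# hypothetical helper; the URLs and pip extras mirror the diff above.
from enum import Enum
from typing import Optional, Tuple


class GpuType(Enum):
    CUDA = "cuda"
    CUDA_AND_DML = "cuda_and_dml"
    ROCM = "rocm"
    CPU = "cpu"
    AUTODETECT = "autodetect"


def torch_source_for(device: GpuType, os_name: str) -> Tuple[Optional[str], str]:
    """Return (extra index URL or None, pip extras) for the selected GPU."""
    url: Optional[str] = None
    optional_modules = "[onnx]"
    if os_name == "Linux":
        if device is GpuType.ROCM:
            url = "https://download.pytorch.org/whl/rocm5.6"
        elif device is GpuType.CPU:
            url = "https://download.pytorch.org/whl/cpu"
    elif os_name == "Windows":
        if device is GpuType.CUDA:
            url = "https://download.pytorch.org/whl/cu121"
            optional_modules = "[xformers,onnx-cuda]"
        elif device is GpuType.CUDA_AND_DML:
            url = "https://download.pytorch.org/whl/cu121"
            optional_modules = "[xformers,onnx-directml]"
    # In all other cases, torch wheels come straight from PyPI.
    return url, optional_modules


if __name__ == "__main__":
    print(torch_source_for(GpuType.ROCM, "Linux"))    # rocm5.6 index, [onnx]
    print(torch_source_for(GpuType.CUDA, "Windows"))  # cu121 index, [xformers,onnx-cuda]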