mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(installer) slightly better typing for GPU selection
This commit is contained in:
parent
ca2bb6f0cc
commit
29bcc4b595
@ -368,23 +368,24 @@ def get_torch_source() -> Tuple[str | None, str | None]:
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
from messages import graphical_accelerator
|
||||
from messages import select_gpu
|
||||
|
||||
# device can be one of: "cuda", "rocm", "cpu", "idk"
|
||||
device = graphical_accelerator()
|
||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||
device = select_gpu()
|
||||
|
||||
url = None
|
||||
optional_modules = "[onnx]"
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
if device.value == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.6"
|
||||
elif device == "cpu":
|
||||
elif device.value == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == "cuda":
|
||||
elif OS == "Windows":
|
||||
if device.value == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device == "cuda_and_dml":
|
||||
if device.value == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
|
@ -5,6 +5,7 @@ Installer user interaction
|
||||
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import HTML, prompt
|
||||
@ -182,39 +183,42 @@ def dest_path(dest=None) -> Path | None:
|
||||
console.rule("Goodbye!")
|
||||
|
||||
|
||||
def graphical_accelerator():
|
||||
class GpuType(Enum):
|
||||
CUDA = "cuda"
|
||||
CUDA_AND_DML = "cuda_and_dml"
|
||||
ROCM = "rocm"
|
||||
CPU = "cpu"
|
||||
AUTODETECT = "autodetect"
|
||||
|
||||
|
||||
def select_gpu() -> GpuType:
|
||||
"""
|
||||
Prompt the user to select the graphical accelerator in their system
|
||||
This does not validate user's choices (yet), but only offers choices
|
||||
valid for the platform.
|
||||
CUDA is the fallback.
|
||||
We may be able to detect the GPU driver by shelling out to `modprobe` or `lspci`,
|
||||
but this is not yet supported or reliable. Also, some users may have exotic preferences.
|
||||
Prompt the user to select the GPU driver
|
||||
"""
|
||||
|
||||
if ARCH == "arm64" and OS != "Darwin":
|
||||
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
||||
return "cpu"
|
||||
return GpuType.CPU
|
||||
|
||||
nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||
"cuda",
|
||||
GpuType.CUDA,
|
||||
)
|
||||
nvidia_with_dml = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||
"cuda_and_dml",
|
||||
GpuType.CUDA_AND_DML,
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||
"rocm",
|
||||
GpuType.ROCM,
|
||||
)
|
||||
cpu = (
|
||||
"no compatible GPU, or specifically prefer to use the CPU",
|
||||
"cpu",
|
||||
"Do not install any GPU support, use CPU for generation (slow)",
|
||||
GpuType.CPU,
|
||||
)
|
||||
idk = (
|
||||
autodetect = (
|
||||
"I'm not sure what to choose",
|
||||
"idk",
|
||||
GpuType.AUTODETECT,
|
||||
)
|
||||
|
||||
options = []
|
||||
@ -231,7 +235,7 @@ def graphical_accelerator():
|
||||
return options[0][1]
|
||||
|
||||
# "I don't know" is always added the last option
|
||||
options.append(idk) # type: ignore
|
||||
options.append(autodetect) # type: ignore
|
||||
|
||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||
|
||||
@ -266,9 +270,9 @@ def graphical_accelerator():
|
||||
),
|
||||
)
|
||||
|
||||
if options[choice][1] == "idk":
|
||||
if options[choice][1] is GpuType.AUTODETECT:
|
||||
console.print(
|
||||
"No problem. We will try to install a version that [i]should[/i] be compatible. :crossed_fingers:"
|
||||
"No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
|
||||
)
|
||||
|
||||
return options[choice][1]
|
||||
|
Loading…
Reference in New Issue
Block a user