mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(installer) slightly better typing for GPU selection
This commit is contained in:
parent
ca2bb6f0cc
commit
29bcc4b595
@ -368,25 +368,26 @@ def get_torch_source() -> Tuple[str | None, str | None]:
|
|||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from messages import graphical_accelerator
|
from messages import select_gpu
|
||||||
|
|
||||||
# device can be one of: "cuda", "rocm", "cpu", "idk"
|
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||||
device = graphical_accelerator()
|
device = select_gpu()
|
||||||
|
|
||||||
url = None
|
url = None
|
||||||
optional_modules = "[onnx]"
|
optional_modules = "[onnx]"
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
if device == "rocm":
|
if device.value == "rocm":
|
||||||
url = "https://download.pytorch.org/whl/rocm5.6"
|
url = "https://download.pytorch.org/whl/rocm5.6"
|
||||||
elif device == "cpu":
|
elif device.value == "cpu":
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
|
||||||
if device == "cuda":
|
elif OS == "Windows":
|
||||||
url = "https://download.pytorch.org/whl/cu121"
|
if device.value == "cuda":
|
||||||
optional_modules = "[xformers,onnx-cuda]"
|
url = "https://download.pytorch.org/whl/cu121"
|
||||||
if device == "cuda_and_dml":
|
optional_modules = "[xformers,onnx-cuda]"
|
||||||
url = "https://download.pytorch.org/whl/cu121"
|
if device.value == "cuda_and_dml":
|
||||||
optional_modules = "[xformers,onnx-directml]"
|
url = "https://download.pytorch.org/whl/cu121"
|
||||||
|
optional_modules = "[xformers,onnx-directml]"
|
||||||
|
|
||||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ Installer user interaction
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from prompt_toolkit import HTML, prompt
|
from prompt_toolkit import HTML, prompt
|
||||||
@ -182,39 +183,42 @@ def dest_path(dest=None) -> Path | None:
|
|||||||
console.rule("Goodbye!")
|
console.rule("Goodbye!")
|
||||||
|
|
||||||
|
|
||||||
def graphical_accelerator():
|
class GpuType(Enum):
|
||||||
|
CUDA = "cuda"
|
||||||
|
CUDA_AND_DML = "cuda_and_dml"
|
||||||
|
ROCM = "rocm"
|
||||||
|
CPU = "cpu"
|
||||||
|
AUTODETECT = "autodetect"
|
||||||
|
|
||||||
|
|
||||||
|
def select_gpu() -> GpuType:
|
||||||
"""
|
"""
|
||||||
Prompt the user to select the graphical accelerator in their system
|
Prompt the user to select the GPU driver
|
||||||
This does not validate user's choices (yet), but only offers choices
|
|
||||||
valid for the platform.
|
|
||||||
CUDA is the fallback.
|
|
||||||
We may be able to detect the GPU driver by shelling out to `modprobe` or `lspci`,
|
|
||||||
but this is not yet supported or reliable. Also, some users may have exotic preferences.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if ARCH == "arm64" and OS != "Darwin":
|
if ARCH == "arm64" and OS != "Darwin":
|
||||||
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
||||||
return "cpu"
|
return GpuType.CPU
|
||||||
|
|
||||||
nvidia = (
|
nvidia = (
|
||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||||
"cuda",
|
GpuType.CUDA,
|
||||||
)
|
)
|
||||||
nvidia_with_dml = (
|
nvidia_with_dml = (
|
||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||||
"cuda_and_dml",
|
GpuType.CUDA_AND_DML,
|
||||||
)
|
)
|
||||||
amd = (
|
amd = (
|
||||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||||
"rocm",
|
GpuType.ROCM,
|
||||||
)
|
)
|
||||||
cpu = (
|
cpu = (
|
||||||
"no compatible GPU, or specifically prefer to use the CPU",
|
"Do not install any GPU support, use CPU for generation (slow)",
|
||||||
"cpu",
|
GpuType.CPU,
|
||||||
)
|
)
|
||||||
idk = (
|
autodetect = (
|
||||||
"I'm not sure what to choose",
|
"I'm not sure what to choose",
|
||||||
"idk",
|
GpuType.AUTODETECT,
|
||||||
)
|
)
|
||||||
|
|
||||||
options = []
|
options = []
|
||||||
@ -231,7 +235,7 @@ def graphical_accelerator():
|
|||||||
return options[0][1]
|
return options[0][1]
|
||||||
|
|
||||||
# "I don't know" is always added the last option
|
# "I don't know" is always added the last option
|
||||||
options.append(idk) # type: ignore
|
options.append(autodetect) # type: ignore
|
||||||
|
|
||||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||||
|
|
||||||
@ -266,9 +270,9 @@ def graphical_accelerator():
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if options[choice][1] == "idk":
|
if options[choice][1] is GpuType.AUTODETECT:
|
||||||
console.print(
|
console.print(
|
||||||
"No problem. We will try to install a version that [i]should[/i] be compatible. :crossed_fingers:"
|
"No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
|
||||||
)
|
)
|
||||||
|
|
||||||
return options[choice][1]
|
return options[choice][1]
|
||||||
|
Loading…
Reference in New Issue
Block a user