fix: add PyTorch extra-index-url to the updater command

This commit is contained in:
Eugene Brodsky 2024-01-25 01:22:33 -05:00
parent 45bf2c7da6
commit 701f14c1e3

View File

@ -5,14 +5,14 @@ pip install <path_to_git_source>.
import os
import platform
from distutils.version import LooseVersion
from importlib.metadata import PackageNotFoundError, distribution
from importlib.metadata import PackageNotFoundError, distribution, distributions
import psutil
import requests
from rich import box, print
from rich.console import Console, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.prompt import Confirm, Prompt
from rich.style import Style
from invokeai.version import __version__
@ -61,6 +61,65 @@ def get_pypi_versions():
return latest_version, latest_release_candidate, versions
def get_torch_extra_index_url() -> str | None:
"""
Determine torch wheel source URL and optional modules based on the user's OS.
"""
resolved_url = None
# In all other cases (like MacOS (MPS) or Linux+CUDA), there is no need to specify the extra index URL.
torch_package_urls = {
"windows_cuda": "https://download.pytorch.org/whl/cu121",
"linux_rocm": "https://download.pytorch.org/whl/rocm5.6",
"linux_cpu": "https://download.pytorch.org/whl/cpu",
}
nvidia_packages_present = (
len([d.metadata["Name"] for d in distributions() if d.metadata["Name"].startswith("nvidia")]) > 0
)
device = "cuda" if nvidia_packages_present else None
manual_gpu_selection_prompt = (
"[bold]We tried and failed to guess your GPU capabilities[/] :thinking_face:. Please select the GPU type:"
)
if OS == "Linux":
if not device:
# do we even need to offer a CPU-only install option?
print(manual_gpu_selection_prompt)
print("1: NVIDIA (CUDA)")
print("2: AMD (ROCm)")
print("3: No GPU - CPU only")
answer = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
match answer:
case "1":
device = "cuda"
case "2":
device = "rocm"
case "3":
device = "cpu"
if device != "cuda":
resolved_url = torch_package_urls[f"linux_{device}"]
if OS == "Windows":
if not device:
print(manual_gpu_selection_prompt)
print("1: NVIDIA (CUDA)")
print("2: No GPU - CPU only")
answer = Prompt.ask("Your choice:", choices=["1", "2"], default="1")
match answer:
case "1":
device = "cuda"
case "2":
device = "cpu"
if device == "cuda":
resolved_url = torch_package_urls[f"windows_{device}"]
return resolved_url
def welcome(latest_release: str, latest_prerelease: str):
@group()
def text():
@ -123,9 +182,13 @@ def main():
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
extras = get_extras()
flags = []
if (index_url := get_torch_extra_index_url()) is not None:
flags.append(f"--extra-index-url {index_url}")
flags = " ".join(flags)
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade {flags}'
print("")
print("")