mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: add PyTorch extra-index-url to the updater command
This commit is contained in:
parent
45bf2c7da6
commit
701f14c1e3
@ -5,14 +5,14 @@ pip install <path_to_git_source>.
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
from importlib.metadata import PackageNotFoundError, distribution
|
from importlib.metadata import PackageNotFoundError, distribution, distributions
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
from rich import box, print
|
from rich import box, print
|
||||||
from rich.console import Console, group
|
from rich.console import Console, group
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.prompt import Prompt
|
from rich.prompt import Confirm, Prompt
|
||||||
from rich.style import Style
|
from rich.style import Style
|
||||||
|
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
@ -61,6 +61,65 @@ def get_pypi_versions():
|
|||||||
return latest_version, latest_release_candidate, versions
|
return latest_version, latest_release_candidate, versions
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_extra_index_url() -> str | None:
|
||||||
|
"""
|
||||||
|
Determine torch wheel source URL and optional modules based on the user's OS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved_url = None
|
||||||
|
|
||||||
|
# In all other cases (like MacOS (MPS) or Linux+CUDA), there is no need to specify the extra index URL.
|
||||||
|
torch_package_urls = {
|
||||||
|
"windows_cuda": "https://download.pytorch.org/whl/cu121",
|
||||||
|
"linux_rocm": "https://download.pytorch.org/whl/rocm5.6",
|
||||||
|
"linux_cpu": "https://download.pytorch.org/whl/cpu",
|
||||||
|
}
|
||||||
|
|
||||||
|
nvidia_packages_present = (
|
||||||
|
len([d.metadata["Name"] for d in distributions() if d.metadata["Name"].startswith("nvidia")]) > 0
|
||||||
|
)
|
||||||
|
device = "cuda" if nvidia_packages_present else None
|
||||||
|
manual_gpu_selection_prompt = (
|
||||||
|
"[bold]We tried and failed to guess your GPU capabilities[/] :thinking_face:. Please select the GPU type:"
|
||||||
|
)
|
||||||
|
|
||||||
|
if OS == "Linux":
|
||||||
|
if not device:
|
||||||
|
# do we even need to offer a CPU-only install option?
|
||||||
|
print(manual_gpu_selection_prompt)
|
||||||
|
print("1: NVIDIA (CUDA)")
|
||||||
|
print("2: AMD (ROCm)")
|
||||||
|
print("3: No GPU - CPU only")
|
||||||
|
answer = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||||
|
match answer:
|
||||||
|
case "1":
|
||||||
|
device = "cuda"
|
||||||
|
case "2":
|
||||||
|
device = "rocm"
|
||||||
|
case "3":
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
if device != "cuda":
|
||||||
|
resolved_url = torch_package_urls[f"linux_{device}"]
|
||||||
|
|
||||||
|
if OS == "Windows":
|
||||||
|
if not device:
|
||||||
|
print(manual_gpu_selection_prompt)
|
||||||
|
print("1: NVIDIA (CUDA)")
|
||||||
|
print("2: No GPU - CPU only")
|
||||||
|
answer = Prompt.ask("Your choice:", choices=["1", "2"], default="1")
|
||||||
|
match answer:
|
||||||
|
case "1":
|
||||||
|
device = "cuda"
|
||||||
|
case "2":
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
resolved_url = torch_package_urls[f"windows_{device}"]
|
||||||
|
|
||||||
|
return resolved_url
|
||||||
|
|
||||||
|
|
||||||
def welcome(latest_release: str, latest_prerelease: str):
|
def welcome(latest_release: str, latest_prerelease: str):
|
||||||
@group()
|
@group()
|
||||||
def text():
|
def text():
|
||||||
@ -123,9 +182,13 @@ def main():
|
|||||||
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
|
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
|
||||||
|
|
||||||
extras = get_extras()
|
extras = get_extras()
|
||||||
|
flags = []
|
||||||
|
if (index_url := get_torch_extra_index_url()) is not None:
|
||||||
|
flags.append(f"--extra-index-url {index_url}")
|
||||||
|
flags = " ".join(flags)
|
||||||
|
|
||||||
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
|
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
|
||||||
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
|
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade {flags}'
|
||||||
|
|
||||||
print("")
|
print("")
|
||||||
print("")
|
print("")
|
||||||
|
Loading…
Reference in New Issue
Block a user