2023-02-01 04:46:36 +00:00
|
|
|
# Copyright (c) 2023 Eugene Brodsky (https://github.com/ebr)
|
2023-01-08 08:09:04 +00:00
|
|
|
"""
|
|
|
|
InvokeAI installer script
|
|
|
|
"""
|
|
|
|
|
2023-01-09 05:13:01 +00:00
|
|
|
import os
|
|
|
|
import platform
|
2024-03-26 03:24:06 +00:00
|
|
|
import re
|
2023-01-16 06:52:22 +00:00
|
|
|
import shutil
|
2023-01-08 08:09:04 +00:00
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
import venv
|
|
|
|
from pathlib import Path
|
2023-01-27 07:10:32 +00:00
|
|
|
from tempfile import TemporaryDirectory
|
2024-01-29 04:49:22 +00:00
|
|
|
from typing import Optional, Tuple
|
2023-01-08 08:09:04 +00:00
|
|
|
|
2023-09-28 13:28:41 +00:00
|
|
|
# Python version range string for the installer (not referenced elsewhere in this part of the file).
SUPPORTED_PYTHON = ">=3.10.0,<=3.11.100"

# Packages installed into the temporary bootstrap venv so the installer itself can import them.
INSTALLER_REQS = [
    "rich",
    "semver",
    "requests",
    "plumbum",
    "prompt-toolkit",
]

# Directory-name prefix of the temporary bootstrap venv; `set_sys_path` uses it to
# recognise which "*-packages" entries of sys.path must be kept.
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"

# Host platform identification, queried once.
_UNAME = platform.uname()
OS = _UNAME.system
ARCH = _UNAME.machine

VERSION = "latest"
|
2023-01-08 08:09:04 +00:00
|
|
|
|
2024-02-02 01:08:13 +00:00
|
|
|
|
2024-03-26 03:24:06 +00:00
|
|
|
def get_version_from_wheel_filename(wheel_filename: str) -> str:
    """Extract the ``X.Y.Z`` version number embedded in a wheel filename.

    The version is the first ``-<digits>.<digits>.<digits>`` run found in the
    name; any text following the three numeric components (for example a
    pre-release suffix) is not part of the returned value.

    Args:
        wheel_filename: File name of the wheel, e.g. ``InvokeAI-4.0.0-py3-none-any.whl``.

    Returns:
        The dotted three-component version string.

    Raises:
        ValueError: If the filename contains no such version.
    """
    found = re.search(r"-(\d+\.\d+\.\d+)", wheel_filename)
    if found is None:
        raise ValueError(f"Could not extract version from wheel filename: {wheel_filename}")
    return found.group(1)
|
|
|
|
|
|
|
|
|
2023-01-08 08:09:04 +00:00
|
|
|
class Installer:
    """
    Deploys an InvokeAI installation into a given path
    """

    # Packages that `bootstrap` installs into the temporary installer venv.
    reqs: list[str] = INSTALLER_REQS

    def __init__(self) -> None:
        """Abort if a virtualenv is already active, then bootstrap and fetch the release list."""
        if "VIRTUAL_ENV" in os.environ:
            print("A virtual environment is already activated. Please 'deactivate' before installation.")
            sys.exit(-1)

        self.bootstrap()
        # `requests` is only importable after bootstrap has installed it.
        self.available_releases = get_github_releases()

    def mktemp_venv(self) -> TemporaryDirectory[str]:
        """
        Creates a temporary virtual environment for the installer itself

        The directory object is also stored on ``self.venv_dir`` and the venv's
        site-packages is made importable in the current process.

        :return: path to the created virtual environment directory
        :rtype: TemporaryDirectory
        """

        # Cleaning up temporary directories on Windows results in a race condition
        # and a stack trace.
        # `ignore_cleanup_errors` was only added in Python 3.10
        tolerate_cleanup_errors = OS == "Windows" and int(platform.python_version_tuple()[1]) >= 10
        if tolerate_cleanup_errors:
            tmp_venv = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX, ignore_cleanup_errors=True)
        else:
            tmp_venv = TemporaryDirectory(prefix=BOOTSTRAP_VENV_PREFIX)

        venv.create(tmp_venv.name, with_pip=True)
        self.venv_dir = tmp_venv
        set_sys_path(Path(tmp_venv.name))

        return tmp_venv

    def bootstrap(self, verbose: bool = False) -> TemporaryDirectory[str] | None:
        """
        Bootstrap the installer venv with packages required at install time

        :param verbose: print the captured pip output when True
        :return: the temporary venv directory, or None if a pip subprocess failed
        """

        print("Initializing the installer. This may take a minute - please wait...")

        tmp_venv = self.mktemp_venv()
        venv_root = Path(tmp_venv.name)
        install_cmd = [
            get_pip_from_venv(venv_root),
            "install",
            "--require-virtualenv",
            "--use-pep517",
            *self.reqs,
        ]

        try:
            # upgrade pip to the latest version to avoid a confusing message
            upgrade_log = upgrade_pip(venv_root)
            if verbose:
                print(upgrade_log)

            # run the install prerequisites installation
            install_log = subprocess.check_output(install_cmd).decode()
            if verbose:
                print(install_log)
        except subprocess.CalledProcessError as err:
            print(err)
            return None

        return tmp_venv

    def app_venv(self, venv_parent: Path) -> Path:
        """
        Create a virtualenv for the InvokeAI installation

        :param venv_parent: directory in which the ``.venv`` directory is created
        :return: path of the created virtualenv
        """

        target = venv_parent / ".venv"

        # Prefer to copy python executables
        # so that updates to system python don't break InvokeAI
        try:
            venv.create(target, with_pip=True)
        except shutil.SameFileError:
            # If installing over an existing environment previously created with symlinks,
            # the executables will fail to copy. Keep symlinks in that case
            venv.create(target, with_pip=True, symlinks=True)

        return target

    def install(
        self,
        root: str = "~/invokeai",
        yes_to_all: bool = False,
        find_links: Optional[str] = None,
        wheel: Optional[Path] = None,
    ) -> None:
        """Install the InvokeAI application into the given runtime path

        Args:
            root: Destination path for the installation
            yes_to_all: Accept defaults to all questions
            find_links: A local directory to search for requirement wheels before going to remote indexes
            wheel: A wheel file to install
        """

        import messages

        # Determine which version is being installed.
        if not wheel:
            messages.welcome(self.available_releases)
            version = messages.choose_version(self.available_releases)
        else:
            messages.installing_from_wheel(wheel.name)
            version = get_version_from_wheel_filename(wheel.name)

        # Determine where to install; INVOKEAI_ROOT overrides `root` for the non-interactive default.
        auto_dest = Path(os.environ.get("INVOKEAI_ROOT", root)).expanduser().resolve()
        if yes_to_all:
            destination = auto_dest
        else:
            destination = messages.dest_path(root)

        if destination is None:
            print("Could not find or create the destination directory. Installation cancelled.")
            sys.exit(0)

        # create the venv for the app
        self.venv = self.app_venv(venv_parent=destination)

        self.instance = InvokeAiInstance(runtime=destination, venv=self.venv, version=version)

        # install dependencies and the InvokeAI application
        if yes_to_all:
            extra_index_url, optional_modules = None, None
        else:
            extra_index_url, optional_modules = get_torch_source()
        self.instance.install(extra_index_url, optional_modules, find_links, wheel)

        # install the launch/update scripts into the runtime directory
        self.instance.install_user_scripts()
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-01-09 08:09:56 +00:00
|
|
|
class InvokeAiInstance:
    """
    Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
    The virtual environment *may* reside within the runtime directory.
    A single runtime directory *may* be shared by multiple virtual environments, though this isn't currently tested or supported.
    """

    def __init__(self, runtime: Path, venv: Path, version: str = "stable") -> None:
        """Record the instance locations and prepare the process environment.

        Besides storing the arguments, this makes the venv's site-packages
        importable in the current process, exports ``INVOKEAI_ROOT`` and
        ``VIRTUAL_ENV`` into ``os.environ``, and upgrades pip inside the venv.

        Args:
            runtime: The runtime (root) directory of the installation.
            venv: The virtual environment directory.
            version: ``"stable"``, ``"prerelease"``, or an explicit version to pin.
        """
        self.runtime = runtime
        self.venv = venv
        self.pip = get_pip_from_venv(venv)
        self.version = version

        set_sys_path(venv)
        os.environ["INVOKEAI_ROOT"] = str(self.runtime.expanduser().resolve())
        os.environ["VIRTUAL_ENV"] = str(self.venv.expanduser().resolve())
        upgrade_pip(venv)

    def get(self) -> tuple[Path, Path]:
        """
        Get the location of the virtualenv directory for this installation

        :return: Paths of the runtime and the venv directory
        :rtype: tuple[Path, Path]
        """

        return (self.runtime, self.venv)

    def _pip_install_args(
        self,
        extra_index_url: Optional[str] = None,
        optional_modules: Optional[str] = None,
        find_links: Optional[str] = None,
        wheel: Optional[Path] = None,
    ) -> list[str]:
        """Build the argument list for ``pip install`` (everything after the executable).

        Only arguments that are actually set are emitted, so the resulting
        command line never contains ``None`` placeholders.
        """
        # not currently used, but may be useful for "install most recent version" option
        if self.version == "prerelease":
            version = None
            pre_flag = "--pre"
        elif self.version == "stable":
            version = None
            pre_flag = None
        else:
            version = self.version
            pre_flag = None

        # Requirement specifier: invokeai[extras]==version
        src = "invokeai"
        if optional_modules:
            src += optional_modules
        if version:
            src += f"=={version}"

        args = [
            "install",
            "--require-virtualenv",
            "--force-reinstall",
            "--use-pep517",
            str(wheel) if wheel else src,
        ]
        if find_links is not None:
            args += ["--find-links", str(find_links)]
        if extra_index_url is not None:
            args += ["--extra-index-url", str(extra_index_url)]
        # Ignore the flag if we are installing a wheel
        if pre_flag is not None and not wheel:
            args.append(pre_flag)
        return args

    def install(
        self,
        extra_index_url: Optional[str] = None,
        optional_modules: Optional[str] = None,
        find_links: Optional[str] = None,
        wheel: Optional[Path] = None,
    ) -> None:
        """Install the package from PyPi or a wheel, if provided.

        Runs pip from this instance's venv in the foreground. If pip fails, an
        error is printed and the process exits with status 1.

        Args:
            extra_index_url: URL passed to pip via "--extra-index-url" to look in an extra index.
            optional_modules: optional modules to install using "[module1,module2]" format.
            find_links: path to a directory containing wheels to be searched prior to going to the internet
            wheel: a wheel file to install
        """

        import messages

        pip_args = self._pip_install_args(extra_index_url, optional_modules, find_links, wheel)

        messages.simple_banner("Installing the InvokeAI Application :art:")

        # plumbum is only importable once the bootstrap venv is on sys.path
        from plumbum import FG, ProcessExecutionError, local

        pip = local[self.pip]
        pipeline = pip[tuple(pip_args)]

        try:
            _ = pipeline & FG
        except ProcessExecutionError as e:
            print(f"Error: {e}")
            print(
                "Could not install InvokeAI. Please try downloading the latest version of the installer and install again."
            )
            sys.exit(1)

    def install_user_scripts(self) -> None:
        """
        Copy the launch and update scripts to the runtime dir
        """

        ext = "bat" if OS == "Windows" else "sh"

        scripts = ["invoke"]

        for script in scripts:
            # Templates live in ../templates relative to this file and carry an ".in" suffix.
            src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
            dest = self.runtime / f"{script}.{ext}"
            shutil.copy(src, dest)
            os.chmod(dest, 0o0755)

    def update(self) -> None:
        """Placeholder; currently does nothing."""
        pass

    def remove(self) -> None:
        """Placeholder; currently does nothing."""
        pass
|
2023-01-09 08:09:56 +00:00
|
|
|
|
|
|
|
|
2023-01-09 18:30:34 +00:00
|
|
|
### Utility functions ###
|
2023-01-09 08:09:56 +00:00
|
|
|
|
|
|
|
|
2023-01-30 04:39:14 +00:00
|
|
|
def get_pip_from_venv(venv_path: Path) -> str:
    """
    Locate the `pip` executable of a virtual environment in a cross-platform
    fashion. The path is computed only; nothing checks that the executable
    really exists inside the virtualenv.

    :param venv_path: Path to the virtual environment
    :type venv_path: Path
    :return: Absolute path to the pip executable
    :rtype: str
    """

    if OS == "Windows":
        relative_pip = "Scripts\\pip.exe"
    else:
        relative_pip = "bin/pip"

    venv_root = venv_path.expanduser().resolve()
    return str(venv_root / relative_pip)
|
2023-01-09 08:09:56 +00:00
|
|
|
|
|
|
|
|
2024-02-06 14:35:24 +00:00
|
|
|
def upgrade_pip(venv_path: Path) -> str | None:
    """
    Upgrade the pip executable in the given virtual environment

    Runs ``<venv python> -m pip install --upgrade pip``.

    :param venv_path: Path to the virtual environment
    :return: decoded output of the pip command, or None if the command failed
             (the error is printed in that case)
    """

    relative_python = "Scripts\\python.exe" if OS == "Windows" else "bin/python"
    interpreter = str(venv_path.expanduser().resolve() / relative_python)
    command = [interpreter, "-m", "pip", "install", "--upgrade", "pip"]

    try:
        return subprocess.check_output(command).decode()
    except subprocess.CalledProcessError as err:
        print(err)
        return None
|
|
|
|
|
|
|
|
|
2023-02-01 03:25:56 +00:00
|
|
|
def set_sys_path(venv_path: Path) -> None:
    """
    Adjust sys.path, in a cross-platform fashion, so that packages from the
    given virtual environment can be imported in the current process.
    Packages from the system environment are hidden (emulating the virtual
    env 'activate' script) - this doesn't work on Windows yet.

    :param venv_path: Path to the virtual environment
    :type venv_path: Path
    """

    # Drop every "*-packages" entry (system- or user-wide package directories),
    # except those of the temporary bootstrap virtualenv: it holds packages that
    # are still needed at install time.
    kept_entries = []
    for entry in sys.path:
        if not entry.endswith("-packages") or BOOTSTRAP_VENV_PREFIX in entry:
            kept_entries.append(entry)
    sys.path = kept_entries

    # determine site-packages/lib directory location for the venv
    if OS == "Windows":
        lib = "Lib"
    else:
        lib = "lib/python{}.{}".format(sys.version_info.major, sys.version_info.minor)

    # make the venv's site-packages importable
    site_packages = Path(venv_path, lib, "site-packages").expanduser().resolve()
    sys.path.append(str(site_packages))
|
2023-01-13 09:11:23 +00:00
|
|
|
|
|
|
|
|
2024-03-26 02:34:00 +00:00
|
|
|
def get_github_releases() -> tuple[list[str], list[str]] | None:
|
2024-02-05 20:34:36 +00:00
|
|
|
"""
|
|
|
|
Query Github for published (pre-)release versions.
|
|
|
|
Return a tuple where the first element is a list of stable releases and the second element is a list of pre-releases.
|
|
|
|
Return None if the query fails for any reason.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
## get latest releases using github api
|
|
|
|
url = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
2024-03-26 02:34:00 +00:00
|
|
|
releases: list[str] = []
|
|
|
|
pre_releases: list[str] = []
|
2024-02-05 20:34:36 +00:00
|
|
|
try:
|
|
|
|
res = requests.get(url)
|
|
|
|
res.raise_for_status()
|
|
|
|
tag_info = res.json()
|
|
|
|
for tag in tag_info:
|
|
|
|
if not tag["prerelease"]:
|
|
|
|
releases.append(tag["tag_name"].lstrip("v"))
|
|
|
|
else:
|
|
|
|
pre_releases.append(tag["tag_name"].lstrip("v"))
|
|
|
|
except requests.HTTPError as e:
|
|
|
|
print(f"Error: {e}")
|
|
|
|
print("Could not fetch version information from GitHub. Please check your network connection and try again.")
|
|
|
|
return
|
|
|
|
except Exception as e:
|
|
|
|
print(f"Error: {e}")
|
|
|
|
print("An unexpected error occurred while trying to fetch version information from GitHub. Please try again.")
|
|
|
|
return
|
|
|
|
|
|
|
|
releases.sort(reverse=True)
|
|
|
|
pre_releases.sort(reverse=True)
|
|
|
|
|
|
|
|
return releases, pre_releases
|
|
|
|
|
|
|
|
|
2024-01-29 04:49:22 +00:00
|
|
|
def get_torch_source() -> Tuple[str | None, str | None]:
|
2023-01-13 09:11:23 +00:00
|
|
|
"""
|
|
|
|
Determine the extra index URL for pip to use for torch installation.
|
|
|
|
This depends on the OS and the graphics accelerator in use.
|
|
|
|
This is only applicable to Windows and Linux, since PyTorch does not
|
|
|
|
offer accelerated builds for macOS.
|
|
|
|
|
2023-01-19 20:34:33 +00:00
|
|
|
Prefer CUDA-enabled wheels if the user wasn't sure of their GPU, as it will fallback to CPU if possible.
|
2023-01-13 09:11:23 +00:00
|
|
|
|
|
|
|
A NoneType return means just go to PyPi.
|
|
|
|
|
2023-02-01 22:41:38 +00:00
|
|
|
:return: tuple consisting of (extra index url or None, optional modules to load or None)
|
2023-01-13 09:11:23 +00:00
|
|
|
:rtype: list
|
|
|
|
"""
|
|
|
|
|
2024-02-05 16:51:01 +00:00
|
|
|
from messages import select_gpu
|
2023-01-13 09:11:23 +00:00
|
|
|
|
2024-02-05 16:51:01 +00:00
|
|
|
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
|
|
|
device = select_gpu()
|
2023-01-13 09:11:23 +00:00
|
|
|
|
|
|
|
url = None
|
2023-07-29 01:02:48 +00:00
|
|
|
optional_modules = "[onnx]"
|
2023-01-13 09:11:23 +00:00
|
|
|
if OS == "Linux":
|
2024-02-05 16:51:01 +00:00
|
|
|
if device.value == "rocm":
|
2024-01-25 07:49:55 +00:00
|
|
|
url = "https://download.pytorch.org/whl/rocm5.6"
|
2024-02-05 16:51:01 +00:00
|
|
|
elif device.value == "cpu":
|
2023-01-13 09:11:23 +00:00
|
|
|
url = "https://download.pytorch.org/whl/cpu"
|
|
|
|
|
2024-02-05 16:51:01 +00:00
|
|
|
elif OS == "Windows":
|
|
|
|
if device.value == "cuda":
|
|
|
|
url = "https://download.pytorch.org/whl/cu121"
|
|
|
|
optional_modules = "[xformers,onnx-cuda]"
|
|
|
|
if device.value == "cuda_and_dml":
|
|
|
|
url = "https://download.pytorch.org/whl/cu121"
|
|
|
|
optional_modules = "[xformers,onnx-directml]"
|
2023-02-01 22:41:38 +00:00
|
|
|
|
2023-01-19 20:34:33 +00:00
|
|
|
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
2023-01-13 09:11:23 +00:00
|
|
|
|
2023-02-01 22:41:38 +00:00
|
|
|
return (url, optional_modules)
|