(installer) install PyTorch from correct repositories

Eugene Brodsky 2023-01-13 04:11:23 -05:00
parent b186965e77
commit 169c56e471


@@ -9,6 +9,7 @@ import sys
 import venv
 from pathlib import Path
 from tempfile import TemporaryDirectory, TemporaryFile
+from typing import Union

 SUPPORTED_PYTHON = ">=3.9.0,<3.11"
 INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
@@ -140,17 +141,17 @@ class Installer:
         :type version: str
         """

-        from messages import dest_path, welcome
+        import messages

-        welcome()
+        messages.welcome()

-        self.dest = dest_path(path)
+        self.dest = messages.dest_path(path)
         self.venv = self.app_venv()

         self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv)

-        self.instance.deploy()
+        self.instance.deploy(extra_index_url=get_torch_source())

         self.instance.configure()
@@ -182,13 +183,32 @@ class InvokeAiInstance:
         return (self.runtime, self.venv)

-    def deploy(self):
+    def deploy(self, extra_index_url=None):
+        """
+        Install packages with pip
+
+        :param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
+        :type extra_index_url: str
+        """

         ### this is all very rough for now as a PoC
         ### source installer basically
+        ### TODO: need to pull the source from Github like the current installer does
+        ### until we continuously build wheels

+        import messages
         from plumbum import local, FG

+        # pre-installing Torch because this is the most reliable way to ensure
+        # the correct version gets installed.
+        # this works with either source or wheel install and has
+        # negligible impact on installation times.
+        messages.simple_banner("Installing PyTorch :fire:")
+        self.install_torch(extra_index_url)
+
+        messages.simple_banner("Installing InvokeAI base dependencies :rocket:")
+        extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None

         pip = local[self.pip]
         (
@@ -199,11 +219,46 @@ class InvokeAiInstance:
                 (Path(__file__).parents[1] / "environments-and-requirements/requirements-base.txt")
                 .expanduser()
                 .resolve(),
+                extra_index_url_arg,
+                extra_index_url,
             ]
             & FG
         )

-        (pip["install", "--require-virtualenv", Path(__file__).parents[1].expanduser().resolve()] & FG)
+        messages.simple_banner("Installing the InvokeAI Application :art:")
+        (
+            pip[
+                "install",
+                "--require-virtualenv",
+                Path(__file__).parents[1].expanduser().resolve(),
+                extra_index_url_arg,
+                extra_index_url,
+            ]
+            & FG
+        )

+    def install_torch(self, extra_index_url=None):
+        """
+        Install PyTorch
+        """
+
+        from plumbum import local, FG
+
+        extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None
+
+        pip = local[self.pip]
+        (
+            pip[
+                "install",
+                "--require-virtualenv",
+                "torch",
+                "torchvision",
+                extra_index_url_arg,
+                extra_index_url,
+            ]
+            & FG
+        )

     def configure(self):
         """
@@ -256,3 +311,40 @@ def add_venv_site(venv_path: Path) -> None:
     lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
     sys.path.append(str(Path(venv_path, lib, "site-packages").absolute()))

+
+def get_torch_source() -> Union[str, None]:
+    """
+    Determine the extra index URL for pip to use for torch installation.
+    This depends on the OS and the graphics accelerator in use.
+    This is only applicable to Windows and Linux, since PyTorch does not
+    offer accelerated builds for macOS.
+
+    Prefer CUDA if the user wasn't sure of their GPU, as it will fall back to CPU if possible.
+
+    A NoneType return means just go to PyPI.
+
+    :return: The extra index URL pointing at the PyTorch wheel source, if available
+    :rtype: str or None
+    """
+
+    from messages import graphical_accelerator
+
+    device = graphical_accelerator()
+
+    url = None
+    if OS == "Linux":
+        if device in ["cuda", "idk"]:
+            url = "https://download.pytorch.org/whl/cu117"
+        elif device == "rocm":
+            url = "https://download.pytorch.org/whl/rocm5.2"
+        else:
+            url = "https://download.pytorch.org/whl/cpu"
+    elif OS == "Windows":
+        if device in ["cuda", "idk"]:
+            url = "https://download.pytorch.org/whl/cu117"
+
+    # ignoring macOS because its wheels come from PyPI anyway (CPU only)
+
+    return url
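
For reference, the OS/accelerator decision that get_torch_source() encodes can be read off as a small lookup table. The sketch below is an illustrative restatement, not code from the commit; "idk" is the accelerator value recorded when the user is unsure of their GPU, and lookup_torch_index() is a hypothetical helper name:

from typing import Optional

# (OS, accelerator) -> extra index URL; combinations not listed fall through below
TORCH_WHEEL_INDEX = {
    ("Linux", "cuda"): "https://download.pytorch.org/whl/cu117",
    ("Linux", "idk"): "https://download.pytorch.org/whl/cu117",
    ("Linux", "rocm"): "https://download.pytorch.org/whl/rocm5.2",
    ("Windows", "cuda"): "https://download.pytorch.org/whl/cu117",
    ("Windows", "idk"): "https://download.pytorch.org/whl/cu117",
}

def lookup_torch_index(os_name: str, device: str) -> Optional[str]:
    # mirror get_torch_source(): on Linux anything that is not CUDA/ROCm falls back
    # to the CPU-only index; Windows without CUDA and all of macOS return None,
    # which means "just install from PyPI"
    if os_name == "Linux":
        return TORCH_WHEEL_INDEX.get((os_name, device), "https://download.pytorch.org/whl/cpu")
    return TORCH_WHEEL_INDEX.get((os_name, device))

For example, lookup_torch_index("Linux", "rocm") returns the ROCm 5.2 wheel index, matching the rocm branch of get_torch_source() above.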