mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
(installer) install PyTorch from correct repositories
This commit is contained in:
parent
b186965e77
commit
169c56e471
@ -9,6 +9,7 @@ import sys
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory, TemporaryFile
|
||||
from typing import Union
|
||||
|
||||
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
|
||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||
@ -140,17 +141,17 @@ class Installer:
|
||||
:type version: str
|
||||
"""
|
||||
|
||||
from messages import dest_path, welcome
|
||||
import messages
|
||||
|
||||
welcome()
|
||||
messages.welcome()
|
||||
|
||||
self.dest = dest_path(path)
|
||||
self.dest = messages.dest_path(path)
|
||||
|
||||
self.venv = self.app_venv()
|
||||
|
||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv)
|
||||
|
||||
self.instance.deploy()
|
||||
self.instance.deploy(extra_index_url=get_torch_source())
|
||||
|
||||
self.instance.configure()
|
||||
|
||||
@ -182,13 +183,32 @@ class InvokeAiInstance:
|
||||
|
||||
return (self.runtime, self.venv)
|
||||
|
||||
def deploy(self):
|
||||
def deploy(self, extra_index_url=None):
|
||||
"""
|
||||
Install packages with pip
|
||||
|
||||
:param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||
:type extra_index_url: str
|
||||
"""
|
||||
|
||||
### this is all very rough for now as a PoC
|
||||
### source installer basically
|
||||
### TODO: need to pull the source from Github like the current installer does
|
||||
### until we continuously build wheels
|
||||
|
||||
import messages
|
||||
from plumbum import local, FG
|
||||
|
||||
# pre-installing Torch because this is the most reliable way to ensure
|
||||
# the correct version gets installed.
|
||||
# this works with either source or wheel install and has
|
||||
# negligible impact on installation times.
|
||||
messages.simple_banner("Installing PyTorch :fire:")
|
||||
self.install_torch(extra_index_url)
|
||||
|
||||
messages.simple_banner("Installing InvokeAI base dependencies :rocket:")
|
||||
extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
(
|
||||
@ -199,11 +219,46 @@ class InvokeAiInstance:
|
||||
(Path(__file__).parents[1] / "environments-and-requirements/requirements-base.txt")
|
||||
.expanduser()
|
||||
.resolve(),
|
||||
extra_index_url_arg,
|
||||
extra_index_url,
|
||||
]
|
||||
& FG
|
||||
)
|
||||
|
||||
(pip["install", "--require-virtualenv", Path(__file__).parents[1].expanduser().resolve()] & FG)
|
||||
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||
(
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
Path(__file__).parents[1].expanduser().resolve(),
|
||||
extra_index_url_arg,
|
||||
extra_index_url,
|
||||
]
|
||||
& FG
|
||||
)
|
||||
|
||||
def install_torch(self, extra_index_url=None):
|
||||
"""
|
||||
Install PyTorch
|
||||
"""
|
||||
|
||||
from plumbum import local, FG
|
||||
|
||||
extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
(
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
extra_index_url_arg,
|
||||
extra_index_url,
|
||||
]
|
||||
& FG
|
||||
)
|
||||
|
||||
def configure(self):
|
||||
"""
|
||||
@ -256,3 +311,40 @@ def add_venv_site(venv_path: Path) -> None:
|
||||
|
||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").absolute()))
|
||||
|
||||
|
||||
def get_torch_source() -> Union[str, None]:
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
This is only applicable to Windows and Linux, since PyTorch does not
|
||||
offer accelerated builds for macOS.
|
||||
|
||||
Prefer CUDA if the user wasn't sure of their GPU, as it will fallback to CPU if possible.
|
||||
|
||||
A NoneType return means just go to PyPi.
|
||||
|
||||
:return: The list of arguments to pip pointing at the PyTorch wheel source, if available
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
from messages import graphical_accelerator
|
||||
|
||||
device = graphical_accelerator()
|
||||
|
||||
url = None
|
||||
if OS == "Linux":
|
||||
if device in ["cuda", "idk"]:
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
elif device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.2"
|
||||
else:
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
elif OS == "Windows":
|
||||
if device in ["cuda", "idk"]:
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
|
||||
# ignoring macOS because its wheels come from PyPi anyway (cpu only)
|
||||
|
||||
return url
|
||||
|
Loading…
Reference in New Issue
Block a user