mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
(installer) install PyTorch from correct repositories
This commit is contained in:
parent
b186965e77
commit
169c56e471
@ -9,6 +9,7 @@ import sys
|
|||||||
import venv
|
import venv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory, TemporaryFile
|
from tempfile import TemporaryDirectory, TemporaryFile
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
|
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
|
||||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||||
@ -140,17 +141,17 @@ class Installer:
|
|||||||
:type version: str
|
:type version: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from messages import dest_path, welcome
|
import messages
|
||||||
|
|
||||||
welcome()
|
messages.welcome()
|
||||||
|
|
||||||
self.dest = dest_path(path)
|
self.dest = messages.dest_path(path)
|
||||||
|
|
||||||
self.venv = self.app_venv()
|
self.venv = self.app_venv()
|
||||||
|
|
||||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv)
|
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv)
|
||||||
|
|
||||||
self.instance.deploy()
|
self.instance.deploy(extra_index_url=get_torch_source())
|
||||||
|
|
||||||
self.instance.configure()
|
self.instance.configure()
|
||||||
|
|
||||||
@ -182,13 +183,32 @@ class InvokeAiInstance:
|
|||||||
|
|
||||||
return (self.runtime, self.venv)
|
return (self.runtime, self.venv)
|
||||||
|
|
||||||
def deploy(self):
|
def deploy(self, extra_index_url=None):
|
||||||
|
"""
|
||||||
|
Install packages with pip
|
||||||
|
|
||||||
|
:param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||||
|
:type extra_index_url: str
|
||||||
|
"""
|
||||||
|
|
||||||
### this is all very rough for now as a PoC
|
### this is all very rough for now as a PoC
|
||||||
### source installer basically
|
### source installer basically
|
||||||
|
### TODO: need to pull the source from Github like the current installer does
|
||||||
|
### until we continuously build wheels
|
||||||
|
|
||||||
|
import messages
|
||||||
from plumbum import local, FG
|
from plumbum import local, FG
|
||||||
|
|
||||||
|
# pre-installing Torch because this is the most reliable way to ensure
|
||||||
|
# the correct version gets installed.
|
||||||
|
# this works with either source or wheel install and has
|
||||||
|
# negligible impact on installation times.
|
||||||
|
messages.simple_banner("Installing PyTorch :fire:")
|
||||||
|
self.install_torch(extra_index_url)
|
||||||
|
|
||||||
|
messages.simple_banner("Installing InvokeAI base dependencies :rocket:")
|
||||||
|
extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None
|
||||||
|
|
||||||
pip = local[self.pip]
|
pip = local[self.pip]
|
||||||
|
|
||||||
(
|
(
|
||||||
@ -199,11 +219,46 @@ class InvokeAiInstance:
|
|||||||
(Path(__file__).parents[1] / "environments-and-requirements/requirements-base.txt")
|
(Path(__file__).parents[1] / "environments-and-requirements/requirements-base.txt")
|
||||||
.expanduser()
|
.expanduser()
|
||||||
.resolve(),
|
.resolve(),
|
||||||
|
extra_index_url_arg,
|
||||||
|
extra_index_url,
|
||||||
]
|
]
|
||||||
& FG
|
& FG
|
||||||
)
|
)
|
||||||
|
|
||||||
(pip["install", "--require-virtualenv", Path(__file__).parents[1].expanduser().resolve()] & FG)
|
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||||
|
(
|
||||||
|
pip[
|
||||||
|
"install",
|
||||||
|
"--require-virtualenv",
|
||||||
|
Path(__file__).parents[1].expanduser().resolve(),
|
||||||
|
extra_index_url_arg,
|
||||||
|
extra_index_url,
|
||||||
|
]
|
||||||
|
& FG
|
||||||
|
)
|
||||||
|
|
||||||
|
def install_torch(self, extra_index_url=None):
|
||||||
|
"""
|
||||||
|
Install PyTorch
|
||||||
|
"""
|
||||||
|
|
||||||
|
from plumbum import local, FG
|
||||||
|
|
||||||
|
extra_index_url_arg = "--extra-index-url" if extra_index_url is not None else None
|
||||||
|
|
||||||
|
pip = local[self.pip]
|
||||||
|
|
||||||
|
(
|
||||||
|
pip[
|
||||||
|
"install",
|
||||||
|
"--require-virtualenv",
|
||||||
|
"torch",
|
||||||
|
"torchvision",
|
||||||
|
extra_index_url_arg,
|
||||||
|
extra_index_url,
|
||||||
|
]
|
||||||
|
& FG
|
||||||
|
)
|
||||||
|
|
||||||
def configure(self):
|
def configure(self):
|
||||||
"""
|
"""
|
||||||
@ -256,3 +311,40 @@ def add_venv_site(venv_path: Path) -> None:
|
|||||||
|
|
||||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||||
sys.path.append(str(Path(venv_path, lib, "site-packages").absolute()))
|
sys.path.append(str(Path(venv_path, lib, "site-packages").absolute()))
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_source() -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Determine the extra index URL for pip to use for torch installation.
|
||||||
|
This depends on the OS and the graphics accelerator in use.
|
||||||
|
This is only applicable to Windows and Linux, since PyTorch does not
|
||||||
|
offer accelerated builds for macOS.
|
||||||
|
|
||||||
|
Prefer CUDA if the user wasn't sure of their GPU, as it will fallback to CPU if possible.
|
||||||
|
|
||||||
|
A NoneType return means just go to PyPi.
|
||||||
|
|
||||||
|
:return: The list of arguments to pip pointing at the PyTorch wheel source, if available
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
|
||||||
|
from messages import graphical_accelerator
|
||||||
|
|
||||||
|
device = graphical_accelerator()
|
||||||
|
|
||||||
|
url = None
|
||||||
|
if OS == "Linux":
|
||||||
|
if device in ["cuda", "idk"]:
|
||||||
|
url = "https://download.pytorch.org/whl/cu117"
|
||||||
|
elif device == "rocm":
|
||||||
|
url = "https://download.pytorch.org/whl/rocm5.2"
|
||||||
|
else:
|
||||||
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
|
||||||
|
elif OS == "Windows":
|
||||||
|
if device in ["cuda", "idk"]:
|
||||||
|
url = "https://download.pytorch.org/whl/cu117"
|
||||||
|
|
||||||
|
# ignoring macOS because its wheels come from PyPi anyway (cpu only)
|
||||||
|
|
||||||
|
return url
|
||||||
|
Loading…
Reference in New Issue
Block a user