install xformers and triton when CUDA torch requested

This commit is contained in:
Lincoln Stein 2023-02-01 17:41:38 -05:00
parent 6cbdd88fe2
commit 11ac50a6ea
2 changed files with 23 additions and 14 deletions

View File

@ -14,7 +14,7 @@ from tempfile import TemporaryDirectory
from typing import Union
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
INSTALLER_REQS = ["pip", "rich", "semver", "requests", "plumbum", "prompt-toolkit"]
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
OS = platform.uname().system
@ -91,7 +91,7 @@ class Installer:
venv_dir = self.mktemp_venv()
pip = get_pip_from_venv(Path(venv_dir.name))
cmd = [pip, "install", "--require-virtualenv", "--use-pep517"]
cmd = [pip, "install", "--require-virtualenv", "--use-pep517", "--upgrade"]
cmd.extend(self.reqs)
try:
@ -156,9 +156,11 @@ class Installer:
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
# install dependencies and the InvokeAI application
self.instance.install(extra_index_url = get_torch_source() if not yes_to_all else None)
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
self.instance.install(
extra_index_url,
optional_modules,
)
# run through the configuration flow
self.instance.configure()
@ -194,9 +196,9 @@ class InvokeAiInstance:
return (self.runtime, self.venv)
def install(self, extra_index_url=None):
def install(self, extra_index_url=None, optional_modules=None):
"""
Install this instance, including depenencies and the app itself
Install this instance, including dependencies and the app itself
:param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
:type extra_index_url: str
@ -210,7 +212,7 @@ class InvokeAiInstance:
self.install_torch(extra_index_url)
messages.simple_banner("Installing the InvokeAI Application :art:")
self.install_app(extra_index_url)
self.install_app(extra_index_url, optional_modules)
def install_torch(self, extra_index_url=None):
"""
@ -233,7 +235,7 @@ class InvokeAiInstance:
& FG
)
def install_app(self, extra_index_url=None):
def install_app(self, extra_index_url=None, optional_modules=None):
"""
Install the application with pip.
Supports installation from PyPi or from a local source directory.
@ -271,7 +273,6 @@ class InvokeAiInstance:
# will install from PyPp
src = f"invokeai=={version}" if version is not None else "invokeai"
import messages
from plumbum import FG, local
pip = local[self.pip]
@ -281,7 +282,7 @@ class InvokeAiInstance:
"install",
"--require-virtualenv",
"--use-pep517",
src,
str(src)+(optional_modules if optional_modules else ''),
"--extra-index-url" if extra_index_url is not None else None,
extra_index_url,
pre,
@ -371,7 +372,7 @@ def set_sys_path(venv_path: Path) -> None:
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
def get_torch_source() -> Union[str, None]:
def get_torch_source() -> (Union[str, None],str):
"""
Determine the extra index URL for pip to use for torch installation.
This depends on the OS and the graphics accelerator in use.
@ -382,7 +383,7 @@ def get_torch_source() -> Union[str, None]:
A NoneType return means just go to PyPi.
:return: The list of arguments to pip pointing at the PyTorch wheel source, if available
:return: tuple consisting of (extra index url or None, optional modules to load or None)
:rtype: list
"""
@ -392,12 +393,16 @@ def get_torch_source() -> Union[str, None]:
device = graphical_accelerator()
url = None
optional_modules = None
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.2"
elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda':
optional_modules = '[xformers]'
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
return url
return (url, optional_modules)

View File

@ -89,6 +89,10 @@ dependencies = [
"mkdocs-redirects==1.2.0",
]
"test" = ["pytest>6.0.0", "pytest-cov"]
"xformers" = [
"xformers~=0.0.16; sys_platform!='darwin'",
"triton; sys_platform!='darwin'",
]
[project.scripts]