Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
install xformers and triton when CUDA torch requested
commit 11ac50a6ea (parent 6cbdd88fe2)
@@ -14,7 +14,7 @@ from tempfile import TemporaryDirectory
 from typing import Union
 
 SUPPORTED_PYTHON = ">=3.9.0,<3.11"
-INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
+INSTALLER_REQS = ["pip", "rich", "semver", "requests", "plumbum", "prompt-toolkit"]
 BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
 
 OS = platform.uname().system
@@ -91,7 +91,7 @@ class Installer:
         venv_dir = self.mktemp_venv()
         pip = get_pip_from_venv(Path(venv_dir.name))
 
-        cmd = [pip, "install", "--require-virtualenv", "--use-pep517"]
+        cmd = [pip, "install", "--require-virtualenv", "--use-pep517", "--upgrade"]
         cmd.extend(self.reqs)
 
         try:
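Taken together, the first two hunks add pip itself to the bootstrap requirements and pass --upgrade, so the temporary installer venv runs a current pip before anything else is installed. A minimal sketch of the resulting command, with a made-up placeholder for the temp venv's pip path:

# Sketch only: the bootstrap command assembled from the new INSTALLER_REQS and
# the added --upgrade flag. The pip path below is a placeholder, not a real path.
INSTALLER_REQS = ["pip", "rich", "semver", "requests", "plumbum", "prompt-toolkit"]

pip = "/tmp/invokeai-installer-tmpXXXX/bin/pip"  # placeholder for the temp venv's pip
cmd = [pip, "install", "--require-virtualenv", "--use-pep517", "--upgrade"]
cmd.extend(INSTALLER_REQS)
print(" ".join(cmd))
# .../pip install --require-virtualenv --use-pep517 --upgrade pip rich semver requests plumbum prompt-toolkit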
@@ -156,9 +156,11 @@ class Installer:
         self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
 
         # install dependencies and the InvokeAI application
-        self.instance.install(extra_index_url = get_torch_source() if not yes_to_all else None)
+        (extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
+        self.instance.install(
+            extra_index_url,
+            optional_modules,
+        )
 
         # run through the configuration flow
         self.instance.configure()
@@ -194,9 +196,9 @@ class InvokeAiInstance:
 
         return (self.runtime, self.venv)
 
-    def install(self, extra_index_url=None):
+    def install(self, extra_index_url=None, optional_modules=None):
         """
-        Install this instance, including depenencies and the app itself
+        Install this instance, including dependencies and the app itself
 
         :param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
         :type extra_index_url: str
@@ -210,7 +212,7 @@ class InvokeAiInstance:
         self.install_torch(extra_index_url)
 
         messages.simple_banner("Installing the InvokeAI Application :art:")
-        self.install_app(extra_index_url)
+        self.install_app(extra_index_url, optional_modules)
 
     def install_torch(self, extra_index_url=None):
         """
@@ -233,7 +235,7 @@ class InvokeAiInstance:
             & FG
         )
 
-    def install_app(self, extra_index_url=None):
+    def install_app(self, extra_index_url=None, optional_modules=None):
         """
         Install the application with pip.
         Supports installation from PyPi or from a local source directory.
@@ -271,7 +273,6 @@ class InvokeAiInstance:
             # will install from PyPi
             src = f"invokeai=={version}" if version is not None else "invokeai"
 
-        import messages
         from plumbum import FG, local
 
         pip = local[self.pip]
@@ -281,7 +282,7 @@ class InvokeAiInstance:
                 "install",
                 "--require-virtualenv",
                 "--use-pep517",
-                src,
+                str(src)+(optional_modules if optional_modules else ''),
                 "--extra-index-url" if extra_index_url is not None else None,
                 extra_index_url,
                 pre,
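With optional_modules set to '[xformers]' on CUDA systems, the requirement handed to pip gains that extra. A minimal sketch of how the spec string is assembled, with the values hard-coded for illustration:

# Sketch only: how install_app() builds the requirement string when CUDA torch
# was selected and no pinned version was requested.
src = "invokeai"                  # or f"invokeai=={version}" for a pinned release
optional_modules = "[xformers]"   # second element returned by get_torch_source()

spec = str(src) + (optional_modules if optional_modules else "")
print(spec)  # invokeai[xformers] -> pip resolves the "xformers" extra from pyproject.toml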
@@ -371,7 +372,7 @@ def set_sys_path(venv_path: Path) -> None:
     sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
 
 
-def get_torch_source() -> Union[str, None]:
+def get_torch_source() -> (Union[str, None],str):
     """
     Determine the extra index URL for pip to use for torch installation.
     This depends on the OS and the graphics accelerator in use.
@@ -382,7 +383,7 @@ def get_torch_source() -> Union[str, None]:
 
     A NoneType return means just go to PyPi.
 
-    :return: The list of arguments to pip pointing at the PyTorch wheel source, if available
+    :return: tuple consisting of (extra index url or None, optional modules to load or None)
     :rtype: list
     """
 
@@ -392,12 +393,16 @@ def get_torch_source() -> Union[str, None]:
     device = graphical_accelerator()
 
     url = None
+    optional_modules = None
     if OS == "Linux":
         if device == "rocm":
             url = "https://download.pytorch.org/whl/rocm5.2"
         elif device == "cpu":
             url = "https://download.pytorch.org/whl/cpu"
 
+    if device == 'cuda':
+        optional_modules = '[xformers]'
+
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
 
-    return url
+    return (url, optional_modules)
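The new return contract can be exercised in isolation. Below is a hedged, parameterized mirror of the logic, not the installer's code; the real get_torch_source() reads the OS from platform.uname() and the device from the graphical_accelerator() prompt:

from typing import Optional, Tuple

def torch_source_sketch(os_name: str, device: str) -> Tuple[Optional[str], Optional[str]]:
    """Illustrative mirror of the new get_torch_source() logic."""
    url = None
    optional_modules = None
    if os_name == "Linux":
        if device == "rocm":
            url = "https://download.pytorch.org/whl/rocm5.2"
        elif device == "cpu":
            url = "https://download.pytorch.org/whl/cpu"
    if device == "cuda":
        optional_modules = "[xformers]"
    return (url, optional_modules)

assert torch_source_sketch("Linux", "cuda") == (None, "[xformers]")
assert torch_source_sketch("Linux", "rocm") == ("https://download.pytorch.org/whl/rocm5.2", None)
assert torch_source_sketch("Windows", "cpu") == (None, None)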
@@ -89,6 +89,10 @@ dependencies = [
     "mkdocs-redirects==1.2.0",
 ]
 "test" = ["pytest>6.0.0", "pytest-cov"]
+"xformers" = [
+    "xformers~=0.0.16; sys_platform!='darwin'",
+    "triton; sys_platform!='darwin'",
+]
 
 [project.scripts]
 
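The final hunk, evidently from the project's pyproject.toml, defines the "xformers" extra that the installer now requests. The sys_platform markers keep pip from pulling xformers and triton onto macOS, where those wheels are not generally available. A quick check of how such a marker evaluates, using the packaging library:

# Sketch: evaluating the environment marker attached to the new "xformers" extra.
from packaging.markers import Marker

marker = Marker("sys_platform != 'darwin'")
print(marker.evaluate())                            # True on Linux/Windows: the extra applies
print(marker.evaluate({"sys_platform": "darwin"}))  # False: xformers/triton are skipped on macOS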