From e158ad85349c133a032d642d43fdf52009f728a3 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Wed, 15 Mar 2023 15:45:48 -0700 Subject: [PATCH] deps: upgrade to PyTorch 2.0 (replaces xformers) --- installer/lib/installer.py | 3 +-- invokeai/backend/stable_diffusion/diffusers_pipeline.py | 3 ++- pyproject.toml | 8 ++------ 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/installer/lib/installer.py b/installer/lib/installer.py index 8ab512eee8..344fa12046 100644 --- a/installer/lib/installer.py +++ b/installer/lib/installer.py @@ -461,8 +461,7 @@ def get_torch_source() -> (Union[str, None],str): url = "https://download.pytorch.org/whl/cpu" if device == 'cuda': - url = 'https://download.pytorch.org/whl/cu117' - optional_modules = '[xformers]' + url = 'https://download.pytorch.org/whl/cu118' # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 152e079693..09d09bf3fb 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id: str = None, additional_guidance: List[Callable] = None, ): - self._adjust_memory_efficient_attention(latents) + # FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms? + # self._adjust_memory_efficient_attention(latents) if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) if additional_guidance is None: diff --git a/pyproject.toml b/pyproject.toml index 9534d0ce07..f5f7d558ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,10 +71,10 @@ dependencies = [ "scikit-image>=0.19", "send2trash", "test-tube>=0.7.5", - "torch>=1.13.1", + "torch~=2.0", "torchvision>=0.14.1", "torchmetrics", - "transformers~=4.26", + "transformers~=4.27", "uvicorn[standard]==0.20.0", "windows-curses; sys_platform=='win32'", ] @@ -90,10 +90,6 @@ dependencies = [ "pudb", ] "test" = ["pytest>6.0.0", "pytest-cov"] -"xformers" = [ - "xformers~=0.0.16; sys_platform!='darwin'", - "triton; sys_platform=='linux'", -] [project.scripts]