deps: upgrade to PyTorch 2.0 (replaces xformers)

This commit is contained in:
Kevin Turner 2023-03-15 15:45:48 -07:00
parent 9738b0ff69
commit e158ad8534
3 changed files with 5 additions and 9 deletions

View File

@ -461,8 +461,7 @@ def get_torch_source() -> (Union[str, None],str):
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda': if device == 'cuda':
url = 'https://download.pytorch.org/whl/cu117' url = 'https://download.pytorch.org/whl/cu118'
optional_modules = '[xformers]'
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
): ):
self._adjust_memory_efficient_attention(latents) # FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
# self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH) run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None: if additional_guidance is None:

View File

@ -71,10 +71,10 @@ dependencies = [
"scikit-image>=0.19", "scikit-image>=0.19",
"send2trash", "send2trash",
"test-tube>=0.7.5", "test-tube>=0.7.5",
"torch>=1.13.1", "torch~=2.0",
"torchvision>=0.14.1", "torchvision>=0.14.1",
"torchmetrics", "torchmetrics",
"transformers~=4.26", "transformers~=4.27",
"uvicorn[standard]==0.20.0", "uvicorn[standard]==0.20.0",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]
@ -90,10 +90,6 @@ dependencies = [
"pudb", "pudb",
] ]
"test" = ["pytest>6.0.0", "pytest-cov"] "test" = ["pytest>6.0.0", "pytest-cov"]
"xformers" = [
"xformers~=0.0.16; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
]
[project.scripts] [project.scripts]