mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
7 Commits
next-test-
...
feat/torch
Author | SHA1 | Date | |
---|---|---|---|
c06e628ff6 | |||
125bd06e6b | |||
f6a5018786 | |||
22a9abdb70 | |||
fb0975bc61 | |||
9377e143a9 | |||
58f4ebc821 |
@ -462,6 +462,10 @@ def get_torch_source() -> (Union[str, None], str):
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
optional_modules = "[xformers]"
|
||||
|
||||
if OS == "Windows":
|
||||
if device == "directml":
|
||||
optional_modules = "[torch-directml]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
return (url, optional_modules)
|
||||
|
@ -171,6 +171,10 @@ def graphical_accelerator():
|
||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||
"rocm",
|
||||
)
|
||||
directml = (
|
||||
"a GPU supporting [gold1 b]DirectML[/] with installed drivers",
|
||||
"directml",
|
||||
)
|
||||
cpu = (
|
||||
"no compatible GPU, or specifically prefer to use the CPU",
|
||||
"cpu",
|
||||
@ -181,7 +185,7 @@ def graphical_accelerator():
|
||||
)
|
||||
|
||||
if OS == "Windows":
|
||||
options = [nvidia, cpu]
|
||||
options = [nvidia, directml, cpu]
|
||||
if OS == "Linux":
|
||||
options = [nvidia, amd, cpu]
|
||||
elif OS == "Darwin":
|
||||
|
@ -339,6 +339,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
# TODO: use DML_DEVICE?
|
||||
if self.device.type == "privateuseone":
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
#self.unet.set_default_attn_processor()
|
||||
return
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
@ -19,6 +19,11 @@ def choose_torch_device() -> torch.device:
|
||||
return CPU_DEVICE
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
try:
|
||||
import torch_directml
|
||||
return torch_directml.device()
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return CPU_DEVICE
|
||||
|
@ -18,6 +18,43 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
import diffusers
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
|
||||
try:
|
||||
import torch_directml
|
||||
torch_directml_installed = True
|
||||
torch_directml_device_type = torch_directml.device().type # privateuseone
|
||||
except:
|
||||
torch_directml_installed = False
|
||||
|
||||
if torch_directml_installed:
|
||||
from contextlib import contextmanager
|
||||
#torch.inference_mode = torch.no_grad
|
||||
|
||||
def empty_enter(self):
|
||||
pass
|
||||
|
||||
torch.inference_mode.__enter__ = empty_enter
|
||||
"""
|
||||
orig_inference_mode = torch.inference_mode
|
||||
@contextmanager
|
||||
def fake_inference_mode(mode):
|
||||
yield
|
||||
#try:
|
||||
# with torch.inference_mode(False):
|
||||
# yield
|
||||
#finally:
|
||||
# pass
|
||||
|
||||
torch.inference_mode = fake_inference_mode
|
||||
"""
|
||||
|
||||
|
||||
orig_torch_generator = torch.Generator
|
||||
def new_torch_Generator(device="cpu"):
|
||||
if torch_directml_installed and type(device) is torch.device and device.type == torch_directml_device_type:
|
||||
device = "cpu"
|
||||
return orig_torch_generator(device)
|
||||
torch.Generator = new_torch_Generator
|
||||
|
||||
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||
|
||||
|
||||
|
@ -98,7 +98,9 @@ dependencies = [
|
||||
"dev" = [
|
||||
"pudb",
|
||||
]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov"]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
|
||||
"torch-directml" = ["torch-directml"]
|
||||
"xformers" = [
|
||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||
"triton; sys_platform=='linux'",
|
||||
|
Reference in New Issue
Block a user