Compare commits

...

7 Commits

6 changed files with 59 additions and 1 deletion

@@ -462,6 +462,10 @@ def get_torch_source() -> (Union[str, None], str):
         url = "https://download.pytorch.org/whl/cu117"
         optional_modules = "[xformers]"
 
+    if OS == "Windows":
+        if device == "directml":
+            optional_modules = "[torch-directml]"
+
     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
     return (url, optional_modules)
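
For orientation, here is a minimal sketch (not part of this diff) of how the `(url, optional_modules)` tuple returned above could drive a pip invocation. The command shape and the `InvokeAI` package spec are assumptions, not the installer's actual code:

    # Hypothetical consumer of get_torch_source() (defined in the hunk above).
    url, optional_modules = get_torch_source()  # e.g. (None, "[torch-directml]") for DirectML on Windows

    cmd = ["pip", "install", f"InvokeAI{optional_modules or ''}"]
    if url is not None:
        cmd += ["--extra-index-url", url]  # e.g. the cu117 wheel index when device == "cuda"
    print(" ".join(cmd))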

@@ -171,6 +171,10 @@ def graphical_accelerator():
         "an [gold1 b]AMD[/] GPU (using ROCm™)",
         "rocm",
     )
+    directml = (
+        "a GPU supporting [gold1 b]DirectML[/] with installed drivers",
+        "directml",
+    )
     cpu = (
         "no compatible GPU, or specifically prefer to use the CPU",
         "cpu",
@@ -181,7 +185,7 @@
     )
     if OS == "Windows":
-        options = [nvidia, cpu]
+        options = [nvidia, directml, cpu]
     if OS == "Linux":
         options = [nvidia, amd, cpu]
     elif OS == "Darwin":
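
Each accelerator option is a `(description, tag)` pair, and the selected tag becomes the `device` string that `get_torch_source()` matches above. A library-agnostic sketch of consuming such pairs (the real installer uses its own prompt UI, which this does not reproduce):

    options = [
        ("an [gold1 b]NVIDIA[/] GPU (using CUDA™)", "cuda"),
        ("a GPU supporting [gold1 b]DirectML[/] with installed drivers", "directml"),
        ("no compatible GPU, or specifically prefer to use the CPU", "cpu"),
    ]
    for i, (description, tag) in enumerate(options, start=1):
        print(f"{i}. {description}")
    device = options[int(input("Select an accelerator: ")) - 1][1]  # e.g. "directml"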

@@ -339,6 +339,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
+        # TODO: use DML_DEVICE?
+        if self.device.type == "privateuseone":
+            self.enable_attention_slicing(slice_size="max")
+            # self.unet.set_default_attn_processor()
+            return
+
         config = InvokeAIAppConfig.get_config()
         if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
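
xFormers requires CUDA, so the new `privateuseone` branch falls back to diffusers' sliced attention instead. A standalone sketch of the same fallback on a plain diffusers pipeline, assuming `torch-directml` is installed; the model ID is illustrative:

    import torch_directml
    from diffusers import StableDiffusionPipeline

    dml = torch_directml.device()  # a torch.device whose .type is "privateuseone"

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to(dml)

    # Sliced attention works on any backend; slice_size="max" minimizes
    # peak memory at some speed cost.
    if pipe.device.type == "privateuseone":
        pipe.enable_attention_slicing(slice_size="max")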

@@ -19,6 +19,11 @@ def choose_torch_device() -> torch.device:
         return CPU_DEVICE
     if torch.cuda.is_available():
         return torch.device("cuda")
+    try:
+        import torch_directml
+        return torch_directml.device()
+    except ModuleNotFoundError:
+        pass
     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         return torch.device("mps")
     return CPU_DEVICE
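
With this change the preference order becomes CUDA, then DirectML (when `torch_directml` imports), then MPS, then CPU. A quick sketch of what callers see; the import path for `choose_torch_device` is an assumption:

    import torch
    from invokeai.backend.util.devices import choose_torch_device  # module path assumed

    device = choose_torch_device()
    print(device)  # e.g. privateuseone:0 when only DirectML is available

    x = torch.randn(2, 3).to(device)  # tensors move to the DML device like any other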

@@ -18,6 +18,43 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
 import diffusers
 from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 
+try:
+    import torch_directml
+    torch_directml_installed = True
+    torch_directml_device_type = torch_directml.device().type  # "privateuseone"
+except ImportError:
+    torch_directml_installed = False
+
+if torch_directml_installed:
+    from contextlib import contextmanager
+
+    # DirectML tensors misbehave under torch.inference_mode, so make entering
+    # the context a no-op.
+    # torch.inference_mode = torch.no_grad
+    def empty_enter(self):
+        pass
+
+    torch.inference_mode.__enter__ = empty_enter
+    """
+    orig_inference_mode = torch.inference_mode
+
+    @contextmanager
+    def fake_inference_mode(mode):
+        yield
+        # try:
+        #     with torch.inference_mode(False):
+        #         yield
+        # finally:
+        #     pass
+
+    torch.inference_mode = fake_inference_mode
+    """
+
+    # torch.Generator is unsupported on the DirectML device; transparently fall
+    # back to a CPU generator whenever one is requested for that device.
+    orig_torch_generator = torch.Generator
+
+    def new_torch_Generator(device="cpu"):
+        if torch_directml_installed and type(device) is torch.device and device.type == torch_directml_device_type:
+            device = "cpu"
+        return orig_torch_generator(device)
+
+    torch.Generator = new_torch_Generator
+
 # Modified ControlNetModel with encoder_attention_mask argument added
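
The `torch.Generator` shim exists because the `privateuseone` backend has no generator support, so seeded RNG has to happen on the CPU. A sketch of the pattern the patch enables, assuming `torch-directml` is installed:

    import torch
    import torch_directml

    dml = torch_directml.device()

    # torch.Generator(device=dml) would fail without the shim; with it, the call
    # silently yields a CPU generator. Seeding on CPU and moving the tensor over
    # keeps results reproducible:
    gen = torch.Generator(device="cpu").manual_seed(42)
    noise = torch.randn(1, 4, 64, 64, generator=gen).to(dml)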

@@ -98,7 +98,9 @@ dependencies = [
 "dev" = [
     "pudb",
 ]
-"test" = ["pytest>6.0.0", "pytest-cov"]
+"test" = ["pytest>6.0.0", "pytest-cov", "black"]
+"torch-directml" = ["torch-directml"]
 "xformers" = [
     "xformers~=0.0.19; sys_platform!='darwin'",
     "triton; sys_platform=='linux'",