Some overrides to make it work

Sergey Borisov 2023-08-01 05:27:38 +03:00
parent 125bd06e6b
commit c06e628ff6
2 changed files with 43 additions and 0 deletions


@@ -339,6 +339,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
# TODO: use DML_DEVICE?
if self.device.type == "privateuseone":
self.enable_attention_slicing(slice_size="max")
#self.unet.set_default_attn_processor()
return
config = InvokeAIAppConfig.get_config()
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()
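
xformers only works on CUDA, so on a DirectML device the method above falls back to max attention slicing. A rough sketch of how this path is reached (illustrative `pipe`, an instance of StableDiffusionGeneratorPipeline; assumes torch_directml is installed):

import torch_directml

dml = torch_directml.device()   # a torch.device with type "privateuseone"
pipe = pipe.to(dml)             # pipe: a StableDiffusionGeneratorPipeline
assert pipe.device.type == "privateuseone"
# The override above now selects enable_attention_slicing(slice_size="max")
# instead of xformers memory-efficient attention.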


@@ -18,6 +18,43 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
try:
    import torch_directml
    torch_directml_installed = True
    torch_directml_device_type = torch_directml.device().type  # "privateuseone"
except Exception:
    # torch_directml is not installed or no DirectML device is available
    torch_directml_installed = False

if torch_directml_installed:
    from contextlib import contextmanager

    #torch.inference_mode = torch.no_grad
    # Override torch.inference_mode so that entering it is a no-op and
    # nothing actually runs in inference mode (it does not work on DirectML).
    def empty_enter(self):
        pass
    torch.inference_mode.__enter__ = empty_enter
    """
    orig_inference_mode = torch.inference_mode
    @contextmanager
    def fake_inference_mode(mode):
        yield
        #try:
        #    with torch.inference_mode(False):
        #        yield
        #finally:
        #    pass
    torch.inference_mode = fake_inference_mode
    """

    # torch.Generator has no DirectML backend, so redirect generator
    # creation for DirectML devices to the CPU.
    orig_torch_generator = torch.Generator
    def new_torch_Generator(device="cpu"):
        if torch_directml_installed and type(device) is torch.device and device.type == torch_directml_device_type:
            device = "cpu"
        return orig_torch_generator(device)
    torch.Generator = new_torch_Generator
# Modified ControlNetModel with encoder_attention_mask argument added
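
With the Generator override in place, code that asks for a generator on the DirectML device transparently gets a CPU generator instead of failing. A minimal sketch (assumes torch_directml is installed; variable names are illustrative):

import torch
import torch_directml

dml = torch_directml.device()

# Without the override this raises, since torch.Generator has no
# "privateuseone" backend; with it, a CPU generator is returned.
gen = torch.Generator(device=dml)
gen.manual_seed(42)
noise = torch.randn(1, 4, 64, 64, generator=gen)  # drawn on the CPU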