mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Some overrides to make it work
This commit is contained in:
parent
125bd06e6b
commit
c06e628ff6
@ -339,6 +339,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
# TODO: use DML_DEVICE?
|
||||
if self.device.type == "privateuseone":
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
#self.unet.set_default_attn_processor()
|
||||
return
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
@ -18,6 +18,43 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
import diffusers
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
|
||||
try:
|
||||
import torch_directml
|
||||
torch_directml_installed = True
|
||||
torch_directml_device_type = torch_directml.device().type # privateuseone
|
||||
except:
|
||||
torch_directml_installed = False
|
||||
|
||||
if torch_directml_installed:
|
||||
from contextlib import contextmanager
|
||||
#torch.inference_mode = torch.no_grad
|
||||
|
||||
def empty_enter(self):
|
||||
pass
|
||||
|
||||
torch.inference_mode.__enter__ = empty_enter
|
||||
"""
|
||||
orig_inference_mode = torch.inference_mode
|
||||
@contextmanager
|
||||
def fake_inference_mode(mode):
|
||||
yield
|
||||
#try:
|
||||
# with torch.inference_mode(False):
|
||||
# yield
|
||||
#finally:
|
||||
# pass
|
||||
|
||||
torch.inference_mode = fake_inference_mode
|
||||
"""
|
||||
|
||||
|
||||
orig_torch_generator = torch.Generator
|
||||
def new_torch_Generator(device="cpu"):
|
||||
if torch_directml_installed and type(device) is torch.device and device.type == torch_directml_device_type:
|
||||
device = "cpu"
|
||||
return orig_torch_generator(device)
|
||||
torch.Generator = new_torch_Generator
|
||||
|
||||
# Modified ControlNetModel with encoder_attention_mask argument added
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user