From c06e628ff6cf14c56ac43990e43e789995b10559 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Tue, 1 Aug 2023 05:27:38 +0300
Subject: [PATCH] Some overrides to make it work

---
 .../stable_diffusion/diffusers_pipeline.py |  6 +++
 invokeai/backend/util/hotfixes.py          | 37 +++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 624d47ff64..ba87732516 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -339,6 +339,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
+        # TODO: use DML_DEVICE?
+        if self.device.type == "privateuseone":
+            self.enable_attention_slicing(slice_size="max")
+            #self.unet.set_default_attn_processor()
+            return
+
         config = InvokeAIAppConfig.get_config()
         if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py
index 4710682ac1..e1ad793002 100644
--- a/invokeai/backend/util/hotfixes.py
+++ b/invokeai/backend/util/hotfixes.py
@@ -18,6 +18,43 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
 import diffusers
 from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 
+try:
+    import torch_directml
+    torch_directml_installed = True
+    torch_directml_device_type = torch_directml.device().type  # privateuseone
+except:
+    torch_directml_installed = False
+
+if torch_directml_installed:
+    from contextlib import contextmanager
+    #torch.inference_mode = torch.no_grad
+
+    def empty_enter(self):
+        pass
+
+    torch.inference_mode.__enter__ = empty_enter
+    """
+    orig_inference_mode = torch.inference_mode
+    @contextmanager
+    def fake_inference_mode(mode):
+        yield
+        #try:
+        #    with torch.inference_mode(False):
+        #        yield
+        #finally:
+        #    pass
+
+    torch.inference_mode = fake_inference_mode
+    """
+
+
+    orig_torch_generator = torch.Generator
+    def new_torch_Generator(device="cpu"):
+        if torch_directml_installed and type(device) is torch.device and device.type == torch_directml_device_type:
+            device = "cpu"
+        return orig_torch_generator(device)
+    torch.Generator = new_torch_Generator
+
 # Modified ControlNetModel with encoder_attention_mask argument added
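
Note for reviewers: on the pipeline side, enable_attention_slicing(slice_size="max")
is the stock diffusers API that the patch falls back to when xformers is
unavailable on a "privateuseone" device. A minimal usage sketch follows; the
checkpoint name is only an illustrative example:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # slice_size="max" computes attention one slice at a time: the slowest
    # but most memory-friendly setting, which is why the patch chooses it
    # for DirectML devices.
    pipe.enable_attention_slicing(slice_size="max")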
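
The hotfixes hunk monkey-patches two torch entry points at import time. Below
is a minimal standalone sketch of the torch.Generator override that runs
without torch_directml installed; it assumes, as the patch does, that DirectML
reports the device type "privateuseone" and that no torch.Generator backend
exists for that device type, hence the CPU fallback:

    import torch

    # Keep a reference to the original constructor so the wrapper can
    # delegate to it, as the patch does.
    orig_torch_generator = torch.Generator

    def new_torch_Generator(device="cpu"):
        # No torch.Generator backend is registered for DirectML's
        # "privateuseone" device, so redirect generator creation to the
        # CPU instead of raising.
        if isinstance(device, torch.device) and device.type == "privateuseone":
            device = "cpu"
        return orig_torch_generator(device)

    torch.Generator = new_torch_Generator

    # A device object of the kind torch_directml.device() returns:
    dml_like = torch.device("privateuseone")
    gen = torch.Generator(device=dml_like).manual_seed(42)
    print(torch.randn(2, generator=gen))  # seeded noise, created on CPU

The inference_mode override takes a blunter route: replacing
torch.inference_mode.__enter__ with a no-op keeps inference mode from ever
being enabled, presumably because the DirectML backend cannot operate on
inference tensors.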