From edc8f5fb6f96e24ed01b1a1514733f9753ad35ab Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Fri, 1 Sep 2023 01:02:47 +0300
Subject: [PATCH] Refactor attention

---
 .../stable_diffusion/diffusers_pipeline.py | 62 +++++++------------
 1 file changed, 24 insertions(+), 38 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 89a22bb416..ae609e5598 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -224,7 +224,6 @@ class StableDiffusionGeneratorPipeline:
                 fn_recursive_set_mem_eff(module)
 
-
     def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
         if hasattr(module, "set_attention_slice"):
             module.set_attention_slice(slice_size)
 
@@ -236,57 +235,44 @@ class StableDiffusionGeneratorPipeline:
         config = InvokeAIAppConfig.get_config()
         if config.attention_type == "xformers":
             self.set_use_memory_efficient_attention_xformers(model, True)
-            return
+
         elif config.attention_type == "sliced":
             slice_size = config.attention_slice_size
             if slice_size == "auto":
                 slice_size = auto_detect_slice_size(latents)
-            elif slice_size == "balanced":
+
+            if slice_size == "balanced":
                 slice_size = "auto"
             self.set_attention_slice(model, slice_size=slice_size)
-            return
+
         elif config.attention_type == "normal":
             self.set_attention_slice(model, slice_size=None)
-            return
+
         elif config.attention_type == "torch-sdp":
             if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                 raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
             model.set_attn_processor(AttnProcessor2_0())
-            return
 
-        # the remainder if this code is called when attention_type=='auto'
-        if model.device.type == "cuda":
-            if is_xformers_available() and not config.disable_xformers:
-                self.set_use_memory_efficient_attention_xformers(model, True)
-                return
-            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                model.set_attn_processor(AttnProcessor2_0())
-                return
+        else:  # auto
+            if model.device.type == "cuda":
+                if is_xformers_available() and not config.disable_xformers:
+                    self.set_use_memory_efficient_attention_xformers(model, True)
 
-        if model.device.type == "cpu" or model.device.type == "mps":
-            mem_free = psutil.virtual_memory().free
-        elif model.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
-        else:
-            raise ValueError(f"unrecognized device {model.device}")
-        # input tensor of [1, 4, h/8, w/8]
-        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
-        max_size_required_for_baddbmm = (
-            16
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * bytes_per_element_needed_for_baddbmm_duplication
-        )
-        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
-            self.set_attention_slice(model, slice_size="max")
-        elif torch.backends.mps.is_available():
-            # diffusers recommends always enabling for mps
-            self.set_attention_slice(model, slice_size="max")
-        else:
-            self.set_attention_slice(model, slice_size=None)
+                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                    model.set_attn_processor(AttnProcessor2_0())
+
+            else:
+                if model.device.type == "cpu" or model.device.type == "mps":
+                    mem_free = psutil.virtual_memory().free
+                elif model.device.type == "cuda":
+                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
+                else:
+                    raise ValueError(f"unrecognized device {model.device}")
+
+                slice_size = auto_detect_slice_size(latents)
+                if slice_size == "balanced":
+                    slice_size = "auto"
+                self.set_attention_slice(model, slice_size=slice_size)
 
     def latents_from_embeddings(
         self,
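A minimal sketch (not part of the patch) of the dispatch the refactored method performs, for reading the branch order at a glance. The function name choose_attention_strategy and its return labels are hypothetical stand-ins; only the attention-type names ("xformers", "sliced", "normal", "torch-sdp", auto) and the CUDA-first fallback order mirror the hunk above, and config details such as disable_xformers are omitted.

# Illustrative sketch only; it mirrors the branch structure of the patch with
# plain values instead of the real pipeline, config, and model objects.
def choose_attention_strategy(
    attention_type: str,       # "xformers" | "sliced" | "normal" | "torch-sdp" | "auto"
    device_type: str,          # "cuda" | "cpu" | "mps"
    xformers_available: bool,
    sdp_available: bool,       # True when torch.nn.functional.scaled_dot_product_attention exists
) -> str:
    """Return a label for the attention setup the refactored code would apply."""
    if attention_type == "xformers":
        return "xformers"
    elif attention_type == "sliced":
        return "sliced"
    elif attention_type == "normal":
        return "no slicing"
    elif attention_type == "torch-sdp":
        if not sdp_available:
            raise RuntimeError("torch-sdp requires PyTorch 2.0")
        return "torch-sdp"
    else:  # auto: prefer xformers, then SDP, on CUDA; otherwise fall back to sliced attention
        if device_type == "cuda":
            if xformers_available:
                return "xformers"
            elif sdp_available:
                return "torch-sdp"
            return "default"  # neither accelerator available: leave the processor unchanged
        return "sliced (auto slice size)"


# Example: a CUDA machine with xformers installed picks xformers under "auto",
# while an MPS machine falls back to sliced attention with an auto-detected slice size.
print(choose_attention_strategy("auto", "cuda", True, True))    # xformers
print(choose_attention_strategy("auto", "mps", False, False))   # sliced (auto slice size)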