Refactor attention

Sergey Borisov 2023-09-01 01:02:47 +03:00
parent 6bb657b3f3
commit edc8f5fb6f


@@ -224,7 +224,6 @@ class StableDiffusionGeneratorPipeline:
             fn_recursive_set_mem_eff(module)
 
     def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
         if hasattr(module, "set_attention_slice"):
             module.set_attention_slice(slice_size)
@@ -236,57 +235,44 @@ class StableDiffusionGeneratorPipeline:
         config = InvokeAIAppConfig.get_config()
         if config.attention_type == "xformers":
             self.set_use_memory_efficient_attention_xformers(model, True)
-            return
         elif config.attention_type == "sliced":
             slice_size = config.attention_slice_size
             if slice_size == "auto":
                 slice_size = auto_detect_slice_size(latents)
-            if slice_size == "balanced":
+            elif slice_size == "balanced":
                 slice_size = "auto"
             self.set_attention_slice(model, slice_size=slice_size)
-            return
         elif config.attention_type == "normal":
             self.set_attention_slice(model, slice_size=None)
-            return
         elif config.attention_type == "torch-sdp":
             if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                 raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
             model.set_attn_processor(AttnProcessor2_0())
-            return
-        # the remainder if this code is called when attention_type=='auto'
-        if model.device.type == "cuda":
-            if is_xformers_available() and not config.disable_xformers:
-                self.set_use_memory_efficient_attention_xformers(model, True)
-                return
-            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                model.set_attn_processor(AttnProcessor2_0())
-                return
-        if model.device.type == "cpu" or model.device.type == "mps":
-            mem_free = psutil.virtual_memory().free
-        elif model.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
-        else:
-            raise ValueError(f"unrecognized device {model.device}")
-        # input tensor of [1, 4, h/8, w/8]
-        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
-        max_size_required_for_baddbmm = (
-            16
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * bytes_per_element_needed_for_baddbmm_duplication
-        )
-        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
-            self.set_attention_slice(model, slice_size="max")
-        elif torch.backends.mps.is_available():
-            # diffusers recommends always enabling for mps
-            self.set_attention_slice(model, slice_size="max")
-        else:
-            self.set_attention_slice(model, slice_size=None)
+        else:  # auto
+            if model.device.type == "cuda":
+                if is_xformers_available() and not config.disable_xformers:
+                    self.set_use_memory_efficient_attention_xformers(model, True)
+                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                    model.set_attn_processor(AttnProcessor2_0())
+            else:
+                if model.device.type == "cpu" or model.device.type == "mps":
+                    mem_free = psutil.virtual_memory().free
+                elif model.device.type == "cuda":
+                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
+                else:
+                    raise ValueError(f"unrecognized device {model.device}")
+                slice_size = auto_detect_slice_size(latents)
+                if slice_size == "balanced":
+                    slice_size = "auto"
+                self.set_attention_slice(model, slice_size=slice_size)
 
     def latents_from_embeddings(
         self,
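
Note on the torch-sdp branch kept above: it depends on torch.nn.functional.scaled_dot_product_attention, which only exists on PyTorch 2.0+. A minimal, self-contained sketch of that gate, assuming a diffusers model that exposes set_attn_processor (the `model` argument in the diff); the helper name enable_torch_sdp is hypothetical and not part of the commit:

    # Sketch only: mirrors the torch-sdp branch from the diff above.
    # AttnProcessor2_0 routes attention through F.scaled_dot_product_attention,
    # so it can only be installed on PyTorch >= 2.0.
    import torch
    from diffusers.models.attention_processor import AttnProcessor2_0

    def enable_torch_sdp(model) -> None:  # hypothetical helper
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # set_attn_processor() swaps the attention implementation on every
        # attention block of the model (e.g. a diffusers UNet2DConditionModel).
        model.set_attn_processor(AttnProcessor2_0())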
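
The removed baddbmm size estimate is easier to follow with concrete numbers. A sketch of the same arithmetic, assuming a 512x512 generation (latents of shape [1, 4, 64, 64]) and fp16 latents (element_size() == 2); the figures are illustrative and not taken from the commit:

    # Worked example of the removed heuristic, under the assumptions above.
    element_size = 2                       # bytes per fp16 latent element (assumed)
    bytes_per_element = element_size + 4   # + 4 bytes duplication overhead, as in the old code
    h8, w8 = 64, 64                        # latents.size(dim=2), latents.size(dim=3) for 512x512

    # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
    max_size_required_for_baddbmm = 16 * h8 * w8 * h8 * w8 * bytes_per_element
    print(max_size_required_for_baddbmm)   # 1_610_612_736 bytes, i.e. 1.5 GiB

    # The old code forced slice_size="max" when this exceeded 3/4 of free memory,
    # i.e. whenever less than about 2 GiB was free on the target device.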