Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
add support for "balanced" attention slice size
invokeai/backend/util/__init__.py

@@ -12,3 +12,4 @@ from .devices import (
 )
 from .log import write_log
 from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
+from .attention import auto_detect_slice_size
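The hunk above re-exports the new helper at the package level (the __init__.py path shown is inferred from the relative imports; the extraction dropped the file header), so callers can import it directly from the package:

    # Import path assumes the hunk above edits invokeai/backend/util/__init__.py.
    from invokeai.backend.util import auto_detect_slice_size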
invokeai/backend/util/attention.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
+"""
+Utility routine used for autodetection of optimal slice size
+for attention mechanism.
+"""
+import torch
+
+
+def auto_detect_slice_size(latents: torch.Tensor) -> str:
+    bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+    max_size_required_for_baddbmm = (
+        16
+        * latents.size(dim=2)
+        * latents.size(dim=3)
+        * latents.size(dim=2)
+        * latents.size(dim=3)
+        * bytes_per_element_needed_for_baddbmm_duplication
+    )
+    if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
+        return "max"
+    elif torch.backends.mps.is_available():
+        return "max"
+    else:
+        return "balanced"
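For a sense of scale, here is a back-of-envelope check of the heuristic with assumed example values (not part of the commit): Stable Diffusion 1.x latents for a 512x512 image have shape [1, 4, 64, 64], and element_size() is 2 for fp16 latents, giving 2 + 4 = 6 bytes per element:

    # Assumed example: SD-1.x latents for a 512x512 image, fp16.
    h = w = 64
    estimate = 16 * h * w * h * w * (2 + 4)  # mirrors the expression in the diff
    print(f"{estimate / 2**30:.2f} GiB")     # prints 1.50 GiB

Since 1.5 GiB exceeds three quarters of free memory whenever less than about 2 GiB is free, the function would return "max" on such a device.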
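Note a bug in the file as committed: mem_free is compared against but never defined or imported, so the first branch raises NameError as soon as the function runs. A minimal sketch of one way to supply it, assuming psutil for system RAM (CPU/MPS) and torch.cuda.mem_get_info for CUDA; the helper name get_mem_free is ours, not the repo's:

    import psutil
    import torch

    def get_mem_free(device: torch.device) -> int:
        # Sketch: free memory in bytes for the device holding the latents.
        if device.type == "cuda":
            free, _ = torch.cuda.mem_get_info(device)  # returns (free, total) in bytes
            return free
        # CPU and MPS share system RAM, so use the OS free-memory figure.
        return psutil.virtual_memory().free

With that helper, auto_detect_slice_size would set mem_free = get_mem_free(latents.device) before the comparison.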