mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Refactor attention

commit edc8f5fb6f (parent 6bb657b3f3)
@@ -224,7 +224,6 @@ class StableDiffusionGeneratorPipeline:
            fn_recursive_set_mem_eff(module)

    def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(slice_size)
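Aside (not part of the commit): set_attention_slice() above simply forwards the slice size to any submodule that exposes a set_attention_slice hook, as diffusers' UNet does. A minimal usage sketch, where `pipeline` stands for an instance of the class above (an illustrative assumption):

    # illustrative only; not part of the diff
    pipeline.set_attention_slice(pipeline.unet, slice_size="max")  # smallest slices, lowest peak memory
    pipeline.set_attention_slice(pipeline.unet, slice_size=None)   # full, unsliced attention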
@@ -236,57 +235,44 @@ class StableDiffusionGeneratorPipeline:
        config = InvokeAIAppConfig.get_config()
        if config.attention_type == "xformers":
            self.set_use_memory_efficient_attention_xformers(model, True)
            return
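Aside (not part of the commit): the branches in this hunk dispatch on two settings read from InvokeAIAppConfig. A minimal sketch of inspecting them, assuming the InvokeAI 3.x import path:

    # illustrative only; the listed values are the ones this method handles
    from invokeai.app.services.config import InvokeAIAppConfig

    config = InvokeAIAppConfig.get_config()
    print(config.attention_type)        # "auto", "normal", "xformers", "sliced", or "torch-sdp"
    print(config.attention_slice_size)  # e.g. "auto", "balanced", "max", or an integer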
        elif config.attention_type == "sliced":
            slice_size = config.attention_slice_size
            if slice_size == "auto":
                slice_size = auto_detect_slice_size(latents)
            elif slice_size == "balanced":
            if slice_size == "balanced":
                slice_size = "auto"
            self.set_attention_slice(model, slice_size=slice_size)
            return
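Aside (not part of the commit): the slice_size values follow diffusers' attention-slicing convention -- roughly, "auto" halves the work per attention step, "max" runs one slice at a time for the lowest peak memory, and an integer gives finer control. The pipeline-level diffusers API expresses the same trade-off, where `pipe` stands for any loaded DiffusionPipeline (an illustrative assumption):

    # rough diffusers equivalent of the sliced branch above; illustrative only
    pipe.enable_attention_slicing("auto")  # moderate memory savings
    pipe.enable_attention_slicing("max")   # maximum savings, slowest
    pipe.disable_attention_slicing()       # back to full attention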
        elif config.attention_type == "normal":
            self.set_attention_slice(model, slice_size=None)
            return
        elif config.attention_type == "torch-sdp":
            if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                raise Exception("torch-sdp requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
            model.set_attn_processor(AttnProcessor2_0())
            return
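Aside (not part of the commit): the torch-sdp branch relies on PyTorch 2.0's built-in scaled-dot-product attention, exposed through diffusers' AttnProcessor2_0. The same capability check in isolation, where `unet` stands for a loaded UNet2DConditionModel (an illustrative assumption):

    # mirrors the guard used in the torch-sdp branch above; illustrative only
    import torch
    from diffusers.models.attention_processor import AttnProcessor2_0

    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        unet.set_attn_processor(AttnProcessor2_0())
    else:
        raise RuntimeError("torch-sdp requires PyTorch 2.0 or newer")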
        # the remainder of this code is called when attention_type == 'auto'
        if model.device.type == "cuda":
            if is_xformers_available() and not config.disable_xformers:
                self.set_use_memory_efficient_attention_xformers(model, True)
                return
            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                model.set_attn_processor(AttnProcessor2_0())
                return
        else:  # auto
            if model.device.type == "cuda":
                if is_xformers_available() and not config.disable_xformers:
                    self.set_use_memory_efficient_attention_xformers(model, True)
        if model.device.type == "cpu" or model.device.type == "mps":
            mem_free = psutil.virtual_memory().free
        elif model.device.type == "cuda":
            mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
        else:
            raise ValueError(f"unrecognized device {model.device}")
        # input tensor of [1, 4, h/8, w/8]
        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
        max_size_required_for_baddbmm = (
            16
            * latents.size(dim=2)
            * latents.size(dim=3)
            * latents.size(dim=2)
            * latents.size(dim=3)
            * bytes_per_element_needed_for_baddbmm_duplication
        )
        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
            self.set_attention_slice(model, slice_size="max")
        elif torch.backends.mps.is_available():
            # diffusers recommends always enabling for mps
            self.set_attention_slice(model, slice_size="max")
        else:
            self.set_attention_slice(model, slice_size=None)
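Aside (not part of the commit): the estimate above sizes the temporary [16, (h/8 * w/8), (h/8 * w/8)] attention tensor that baddbmm materializes. A worked example for a 512x512 image, whose latents are [1, 4, 64, 64], assuming fp16 latents (element_size() == 2):

    # worked example of max_size_required_for_baddbmm for 512x512, fp16 latents
    bytes_per_element = 2 + 4                            # latents.element_size() + 4
    size = 16 * 64 * 64 * 64 * 64 * bytes_per_element
    print(size)                                          # 1610612736 bytes, i.e. 1.5 GiB

With the mem_free * 3.0 / 4.0 threshold, that forces slice_size="max" unless roughly 2 GiB (size * 4/3) of memory is free.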
                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                    model.set_attn_processor(AttnProcessor2_0())

            else:
                if model.device.type == "cpu" or model.device.type == "mps":
                    mem_free = psutil.virtual_memory().free
                elif model.device.type == "cuda":
                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
                else:
                    raise ValueError(f"unrecognized device {model.device}")

                slice_size = auto_detect_slice_size(latents)
                if slice_size == "balanced":
                    slice_size = "auto"
                self.set_attention_slice(model, slice_size=slice_size)

    def latents_from_embeddings(
        self,