mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactor attention
This commit is contained in:
parent
6bb657b3f3
commit
edc8f5fb6f
@ -224,7 +224,6 @@ class StableDiffusionGeneratorPipeline:
|
|||||||
|
|
||||||
fn_recursive_set_mem_eff(module)
|
fn_recursive_set_mem_eff(module)
|
||||||
|
|
||||||
|
|
||||||
def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
|
def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
|
||||||
if hasattr(module, "set_attention_slice"):
|
if hasattr(module, "set_attention_slice"):
|
||||||
module.set_attention_slice(slice_size)
|
module.set_attention_slice(slice_size)
|
||||||
@ -236,57 +235,44 @@ class StableDiffusionGeneratorPipeline:
|
|||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
if config.attention_type == "xformers":
|
if config.attention_type == "xformers":
|
||||||
self.set_use_memory_efficient_attention_xformers(model, True)
|
self.set_use_memory_efficient_attention_xformers(model, True)
|
||||||
return
|
|
||||||
elif config.attention_type == "sliced":
|
elif config.attention_type == "sliced":
|
||||||
slice_size = config.attention_slice_size
|
slice_size = config.attention_slice_size
|
||||||
if slice_size == "auto":
|
if slice_size == "auto":
|
||||||
slice_size = auto_detect_slice_size(latents)
|
slice_size = auto_detect_slice_size(latents)
|
||||||
elif slice_size == "balanced":
|
|
||||||
|
if slice_size == "balanced":
|
||||||
slice_size = "auto"
|
slice_size = "auto"
|
||||||
self.set_attention_slice(model, slice_size=slice_size)
|
self.set_attention_slice(model, slice_size=slice_size)
|
||||||
return
|
|
||||||
elif config.attention_type == "normal":
|
elif config.attention_type == "normal":
|
||||||
self.set_attention_slice(model, slice_size=None)
|
self.set_attention_slice(model, slice_size=None)
|
||||||
return
|
|
||||||
elif config.attention_type == "torch-sdp":
|
elif config.attention_type == "torch-sdp":
|
||||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
model.set_attn_processor(AttnProcessor2_0())
|
model.set_attn_processor(AttnProcessor2_0())
|
||||||
return
|
|
||||||
|
|
||||||
# the remainder if this code is called when attention_type=='auto'
|
else: # auto
|
||||||
if model.device.type == "cuda":
|
if model.device.type == "cuda":
|
||||||
if is_xformers_available() and not config.disable_xformers:
|
if is_xformers_available() and not config.disable_xformers:
|
||||||
self.set_use_memory_efficient_attention_xformers(model, True)
|
self.set_use_memory_efficient_attention_xformers(model, True)
|
||||||
return
|
|
||||||
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
model.set_attn_processor(AttnProcessor2_0())
|
model.set_attn_processor(AttnProcessor2_0())
|
||||||
return
|
|
||||||
|
|
||||||
|
else:
|
||||||
if model.device.type == "cpu" or model.device.type == "mps":
|
if model.device.type == "cpu" or model.device.type == "mps":
|
||||||
mem_free = psutil.virtual_memory().free
|
mem_free = psutil.virtual_memory().free
|
||||||
elif model.device.type == "cuda":
|
elif model.device.type == "cuda":
|
||||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
|
mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unrecognized device {model.device}")
|
raise ValueError(f"unrecognized device {model.device}")
|
||||||
# input tensor of [1, 4, h/8, w/8]
|
|
||||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
slice_size = auto_detect_slice_size(latents)
|
||||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
if slice_size == "balanced":
|
||||||
max_size_required_for_baddbmm = (
|
slice_size = "auto"
|
||||||
16
|
self.set_attention_slice(model, slice_size=slice_size)
|
||||||
* latents.size(dim=2)
|
|
||||||
* latents.size(dim=3)
|
|
||||||
* latents.size(dim=2)
|
|
||||||
* latents.size(dim=3)
|
|
||||||
* bytes_per_element_needed_for_baddbmm_duplication
|
|
||||||
)
|
|
||||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
|
|
||||||
self.set_attention_slice(model, slice_size="max")
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
# diffusers recommends always enabling for mps
|
|
||||||
self.set_attention_slice(model, slice_size="max")
|
|
||||||
else:
|
|
||||||
self.set_attention_slice(model, slice_size=None)
|
|
||||||
|
|
||||||
def latents_from_embeddings(
|
def latents_from_embeddings(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user