Strategize slicing based on free [V]RAM (#2572)

Strategize slicing based on free [V]RAM when not using xformers. Free [V]RAM is evaluated at every generation. When there's enough memory, the entire generation occurs without slicing. If there is not enough free memory, we use diffusers' sliced attention.
This commit is contained in:
Jonathan
2023-02-12 12:24:15 -06:00
committed by GitHub
parent 7c86130a3d
commit 9eed1919c2
2 changed files with 30 additions and 9 deletions

View File

@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager
)
self._enable_memory_efficient_attention()
def _enable_memory_efficient_attention(self):
def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass
else:
self.enable_attention_slicing(slice_size='max')
if self.device.type == 'cpu' or self.device.type == 'mps':
mem_free = psutil.virtual_memory().free
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
else:
raise ValueError(f"unrecognized device {device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
max_size_required_for_baddbmm = \
16 * \
latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
bytes_per_element_needed_for_baddbmm_duplication
if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size='max')
else:
self.disable_attention_slicing()
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id: str = None,
additional_guidance: List[Callable] = None):
self._adjust_memory_efficient_attention(latents)
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None: