Strategize slicing based on free [V]RAM (#2572)

Strategize slicing based on free [V]RAM when not using xformers. Free [V]RAM is evaluated at every generation. When there's enough memory, the entire generation occurs without slicing. If there is not enough free memory, we use diffusers' sliced attention.
Jonathan, 2023-02-12 12:24:15 -06:00 (committed by GitHub)
parent 7c86130a3d
commit 9eed1919c2
2 changed files with 30 additions and 9 deletions
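
In rough terms, the check the new code performs looks like this (a minimal standalone sketch of the logic in the diff below, not the diff itself; the choose_attention_slicing helper and the pipeline argument are illustrative names):

import psutil
import torch

def choose_attention_slicing(pipeline, latents: torch.Tensor) -> None:
    # Measure free [V]RAM right before the generation starts.
    if pipeline.device.type == 'cuda':
        mem_free, _ = torch.cuda.mem_get_info(pipeline.device)
    else:  # 'cpu' and 'mps' fall back to free system RAM
        mem_free = psutil.virtual_memory().free

    # latents are [1, 4, h/8, w/8]; attention's baddbmm materializes a
    # [16, (h/8 * w/8), (h/8 * w/8)] tensor, so the requirement grows with
    # the square of the latent spatial size.
    spatial = latents.size(dim=2) * latents.size(dim=3)
    needed = 16 * spatial * spatial * (latents.element_size() + 4)

    if needed > mem_free * 3.3 / 4.0:  # safety margin carried over from old Invoke code
        pipeline.enable_attention_slicing(slice_size='max')  # not enough headroom: slice
    else:
        pipeline.disable_attention_slicing()  # enough free memory: run unsliced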


@@ -223,7 +223,7 @@ class Generate:
         self.model_name = model or fallback
 
         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()
 
         # gets rid of annoying messages about random seed
@@ -592,20 +592,24 @@ class Generate:
             self.print_cuda_stats()
         return results
 
-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()
 
     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ class Generate:
 
     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
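
Side note on the Generate changes above: the stats calls now pass the active device explicitly, so peak and current VRAM are reported for the device actually in use rather than the default CUDA device. A minimal standalone sketch of the same torch.cuda calls (the cuda:0 device is an illustrative stand-in for self.device):

import torch

if torch.cuda.is_available():
    device = torch.device('cuda:0')  # illustrative; Generate passes self.device
    current = torch.cuda.memory_allocated(device)    # bytes allocated right now
    peak = torch.cuda.max_memory_allocated(device)   # peak bytes since the last reset
    print('>> VRAM in use: %4.2fG, peak: %4.2fG' % (current / 1e9, peak / 1e9))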


@@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )
 
-        self._enable_memory_efficient_attention()
-
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {self.device}")
+
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                  noise: torch.Tensor,
                                  run_id: str = None,
                                  additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
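
To make the slicing heuristic concrete: a 512x512 generation uses [1, 4, 64, 64] latents, so with float16 latents the estimate works out to about 1.6 GB, and slicing is enabled whenever less than roughly 2 GB of [V]RAM is free. A small worked example (the float16 dtype and 64x64 latent size are illustrative assumptions):

import torch

# 512x512 image -> 64x64 latents (h/8 x w/8); float16 assumed for this example
latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)
bytes_per_element = latents.element_size() + 4        # 2 + 4 = 6
spatial = latents.size(dim=2) * latents.size(dim=3)   # 64 * 64 = 4096
max_size_required_for_baddbmm = 16 * spatial * spatial * bytes_per_element
print(max_size_required_for_baddbmm)                  # 1610612736 bytes, about 1.6 GB
# Slicing is enabled when free memory is below 1.6 GB / (3.3 / 4.0), about 1.95 GB.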