From 9eed1919c2071f9199996df747c8638c4a75e8fb Mon Sep 17 00:00:00 2001
From: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Date: Sun, 12 Feb 2023 12:24:15 -0600
Subject: [PATCH] Strategize slicing based on free [V]RAM (#2572)

Strategize slicing based on free [V]RAM when not using xformers. Free
[V]RAM is evaluated at every generation. When there is enough memory,
the entire generation occurs without slicing. If there is not enough
free memory, we fall back to diffusers' sliced attention.
---
 ldm/generate.py                            | 15 +++++++++-----
 ldm/invoke/generator/diffusers_pipeline.py | 24 ++++++++++++++++++----
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 32a6a929a8..8cb3058694 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -223,7 +223,7 @@ class Generate:
         self.model_name = model or fallback

         # for VRAM usage statistics
-        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
+        self.session_peakmem = torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None
         transformers.logging.set_verbosity_error()

         # gets rid of annoying messages about random seed
@@ -592,20 +592,24 @@ class Generate:
         self.print_cuda_stats()
         return results

-    def clear_cuda_cache(self):
+    def gather_cuda_stats(self):
         if self._has_cuda():
             self.max_memory_allocated = max(
                 self.max_memory_allocated,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
             self.memory_allocated = max(
                 self.memory_allocated,
-                torch.cuda.memory_allocated()
+                torch.cuda.memory_allocated(self.device)
             )
             self.session_peakmem = max(
                 self.session_peakmem,
-                torch.cuda.max_memory_allocated()
+                torch.cuda.max_memory_allocated(self.device)
             )
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.gather_cuda_stats()
             torch.cuda.empty_cache()

     def clear_cuda_stats(self):
@@ -614,6 +618,7 @@ class Generate:

     def print_cuda_stats(self):
         if self._has_cuda():
+            self.gather_cuda_stats()
             print(
                 '>> Max VRAM used for this generation:',
                 '%4.2fG.' % (self.max_memory_allocated / 1e9),
diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index f065a0ec2d..24626247cf 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -301,10 +301,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )

-        self._enable_memory_efficient_attention()
-
-    def _enable_memory_efficient_attention(self):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -317,7 +315,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                 pass
             else:
-                self.enable_attention_slicing(slice_size='max')
+                if self.device.type == 'cpu' or self.device.type == 'mps':
+                    mem_free = psutil.virtual_memory().free
+                elif self.device.type == 'cuda':
+                    mem_free, _ = torch.cuda.mem_get_info(self.device)
+                else:
+                    raise ValueError(f"unrecognized device {self.device}")
+                # input tensor of [1, 4, h/8, w/8]
+                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
+                max_size_required_for_baddbmm = \
+                    16 * \
+                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
+                    bytes_per_element_needed_for_baddbmm_duplication
+                if max_size_required_for_baddbmm > (mem_free * 3.3 / 4.0):  # 3.3 / 4.0 is from old Invoke code
+                    self.enable_attention_slicing(slice_size='max')
+                else:
+                    self.disable_attention_slicing()
+
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -377,6 +392,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                        noise: torch.Tensor,
                        run_id: str = None,
                        additional_guidance: List[Callable] = None):
+        self._adjust_memory_efficient_attention(latents)
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
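
The sizing heuristic above can be sanity-checked outside the pipeline: the
worst-case baddbmm buffer grows quadratically with the latent area. Below is
a minimal standalone sketch of the same calculation; the `would_slice` helper
name and the 512x512 example are illustrative assumptions, not part of this
patch, which performs the check inside `_adjust_memory_efficient_attention`.

    import psutil
    import torch

    def would_slice(latents: torch.Tensor, device: torch.device) -> bool:
        # Same heuristic as the patch: enable sliced attention only when the
        # worst-case attention scores tensor would not comfortably fit in
        # free [V]RAM.
        if device.type in ('cpu', 'mps'):
            mem_free = psutil.virtual_memory().free
        elif device.type == 'cuda':
            mem_free, _ = torch.cuda.mem_get_info(device)
        else:
            raise ValueError(f"unrecognized device {device}")
        # latents are [1, 4, h/8, w/8]; the baddbmm output expands to
        # [16, (h/8 * w/8), (h/8 * w/8)], i.e. quadratic in the latent area
        bytes_per_element = latents.element_size() + 4
        h, w = latents.size(dim=2), latents.size(dim=3)
        max_size_required = 16 * (h * w) ** 2 * bytes_per_element
        return max_size_required > mem_free * 3.3 / 4.0

    if __name__ == '__main__':
        # A 512x512 image gives 64x64 latents, so the worst case is
        # 16 * 4096 * 4096 * 6 bytes ~= 1.6 GB for fp16 latents; slicing kicks
        # in once free memory drops below roughly 1.2x that figure (~2 GB).
        latents = torch.zeros(1, 4, 64, 64, dtype=torch.float16)
        print(would_slice(latents, torch.device('cpu')))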