diff --git a/ldm/generate.py b/ldm/generate.py
index 9101ac3f01..fa4e603499 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -344,6 +344,7 @@ class Generate:
             **args,
     ):   # eat up additional cruft
+        self.clear_cuda_stats()
         """
         ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
         It takes the following arguments:
@@ -548,6 +549,7 @@ class Generate:
                 inpaint_width    = inpaint_width,
                 enable_image_debugging = enable_image_debugging,
                 free_gpu_mem=self.free_gpu_mem,
+                clear_cuda_cache=self.clear_cuda_cache
             )

             if init_color:
@@ -566,8 +568,7 @@ class Generate:

         except KeyboardInterrupt:
             # Clear the CUDA cache on an exception
-            if self._has_cuda():
-                torch.cuda.empty_cache()
+            self.clear_cuda_cache()

             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
@@ -575,8 +576,7 @@ class Generate:
                 raise KeyboardInterrupt
         except RuntimeError:
             # Clear the CUDA cache on an exception
-            if self._has_cuda():
-                torch.cuda.empty_cache()
+            self.clear_cuda_cache()

             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')
@@ -587,22 +587,42 @@ class Generate:
             f'>> {len(results)} image(s) generated in',
             '%4.2fs' % ( toc - tic)
         )

+        self.print_cuda_stats()
+        return results
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.max_memory_allocated = max(
+                self.max_memory_allocated,
+                torch.cuda.max_memory_allocated()
+            )
+            self.memory_allocated = max(
+                self.memory_allocated,
+                torch.cuda.memory_allocated()
+            )
+            self.session_peakmem = max(
+                self.session_peakmem,
+                torch.cuda.max_memory_allocated()
+            )
+            torch.cuda.empty_cache()
+
+    def clear_cuda_stats(self):
+        self.max_memory_allocated = 0
+        self.memory_allocated = 0
+
+    def print_cuda_stats(self):
         if self._has_cuda():
             print(
                 '>> Max VRAM used for this generation:',
-                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
+                '%4.2fG.' % (self.max_memory_allocated / 1e9),
                 'Current VRAM utilization:',
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+                '%4.2fG' % (self.memory_allocated / 1e9),
             )
-            self.session_peakmem = max(
-                self.session_peakmem, torch.cuda.max_memory_allocated()
-            )
             print(
                 '>> Max VRAM used since script start: ',
                 '%4.2fG' % (self.session_peakmem / 1e9),
             )
-        return results

     # this needs to be generalized to all sorts of postprocessors, which should be wrapped
     # in a nice harmonized call signature. For now we have a bunch of if/elses!
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index f30ab256ae..048361c057 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -123,8 +123,9 @@ class Generator:
                     seed = self.new_seed()

         # Free up memory from the last generation.
-        if self.model.device.type == 'cuda':
-            torch.cuda.empty_cache()
+        clear_cuda_cache = kwargs.get('clear_cuda_cache', None)
+        if clear_cuda_cache is not None:
+            clear_cuda_cache()

         return results

diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 0e9493aa44..9632a8d4b0 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -66,8 +66,9 @@ class Txt2Img2Img(Generator):
                 )

         # Free up memory from the last generation.
-        if self.model.device.type == 'cuda':
-            torch.cuda.empty_cache()
+        clear_cuda_cache = kwargs.get('clear_cuda_cache', None)
+        if clear_cuda_cache is not None:
+            clear_cuda_cache()

         second_pass_noise = self.get_noise_like(resized_latents)

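For reference, the callback pattern this diff introduces can be reduced to the minimal sketch below, assuming only PyTorch (the `Coordinator` and `worker` names are illustrative stand-ins for `Generate` and `Generator.generate()`, not part of the codebase): the owner of the VRAM statistics snapshots the allocator counters before every cache flush, and the generator receives the bound method instead of calling `torch.cuda.empty_cache()` itself.

```python
import torch

class Coordinator:
    """Stands in for Generate: owns the VRAM counters and the cache flush."""
    def __init__(self):
        self.max_memory_allocated = 0
        self.memory_allocated = 0
        self.session_peakmem = 0

    def _has_cuda(self):
        return torch.cuda.is_available()

    def clear_cuda_cache(self):
        # Snapshot the allocator counters *before* flushing, so the peak
        # figures survive the empty_cache() call and can be printed later.
        if self._has_cuda():
            self.max_memory_allocated = max(
                self.max_memory_allocated, torch.cuda.max_memory_allocated())
            self.memory_allocated = max(
                self.memory_allocated, torch.cuda.memory_allocated())
            self.session_peakmem = max(
                self.session_peakmem, torch.cuda.max_memory_allocated())
            torch.cuda.empty_cache()

def worker(clear_cuda_cache=None):
    """Stands in for a generator: device-agnostic, flushes via the callback."""
    # ... generation work would happen here ...
    if clear_cuda_cache is not None:  # mirrors the base.py / txt2img2img.py hunks
        clear_cuda_cache()

coord = Coordinator()
worker(clear_cuda_cache=coord.clear_cuda_cache)
print('>> Max VRAM used since script start: %4.2fG' % (coord.session_peakmem / 1e9))
```

Routing every flush through one bound method keeps the bookkeeping in a single place and removes the generators' direct dependency on CUDA, so the same generator code runs unchanged on CPU or MPS devices, where the callback is simply omitted and no flush occurs.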