diff --git a/ldm/generate.py b/ldm/generate.py index 7695c3a0bc..413a1e25cb 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -650,6 +650,8 @@ class Generate: def clear_cuda_cache(self): if self._has_cuda(): self.gather_cuda_stats() + # Run garbage collection prior to emptying the CUDA cache + gc.collect() torch.cuda.empty_cache() def clear_cuda_stats(self):