Refactor CUDA cache clearing to add statistical reporting. (#2553)

This commit is contained in:
Jonathan
2023-02-06 11:53:30 -06:00
committed by GitHub
parent 8ab66a211c
commit 28b40bebbe
3 changed files with 36 additions and 14 deletions

View File

@ -66,8 +66,9 @@ class Txt2Img2Img(Generator):
)
# Free up memory from the last generation.
if self.model.device.type == 'cuda':
torch.cuda.empty_cache()
clear_cuda_cache = kwargs['clear_cuda_cache'] or None
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(resized_latents)