Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Refactor CUDA cache clearing to add statistical reporting. (#2553)
parent 8ab66a211c
commit 28b40bebbe
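In outline, the diff below replaces the generators' direct torch.cuda.empty_cache() calls with a clear_cuda_cache callback owned by Generate, which snapshots the CUDA allocator counters before emptying the cache; clear_cuda_stats() zeroes the per-generation counters at the top of prompt2image(), and print_cuda_stats() reports them once results are in. A minimal sketch of that round trip, assuming a hypothetical run_sampler() in place of the real generator call:

import torch

def run_sampler(clear_cuda_cache=None):
    # hypothetical stand-in for Generator.generate(); mirrors the patched
    # behavior of invoking the callback when one is supplied
    if clear_cuda_cache is not None:
        clear_cuda_cache()
    return []

class GenerateSketch:
    def __init__(self):
        self.max_memory_allocated = 0   # peak bytes for the current generation
        self.session_peakmem = 0        # peak bytes since the script started

    def clear_cuda_cache(self):
        if torch.cuda.is_available():
            # record the allocator's high-water mark before releasing the cache
            peak = torch.cuda.max_memory_allocated()
            self.max_memory_allocated = max(self.max_memory_allocated, peak)
            self.session_peakmem = max(self.session_peakmem, peak)
            torch.cuda.empty_cache()

    def prompt2image(self, prompt):
        self.max_memory_allocated = 0                                 # clear_cuda_stats()
        results = run_sampler(clear_cuda_cache=self.clear_cuda_cache)
        print('>> Max VRAM used for this generation: %4.2fG.'         # print_cuda_stats()
              % (self.max_memory_allocated / 1e9))
        return results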
@@ -344,6 +344,7 @@ class Generate:
             **args,
     ):  # eat up additional cruft
+        self.clear_cuda_stats()
         """
         ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
 
         It takes the following arguments:
@@ -548,6 +549,7 @@ class Generate:
                 inpaint_width = inpaint_width,
                 enable_image_debugging = enable_image_debugging,
                 free_gpu_mem=self.free_gpu_mem,
+                clear_cuda_cache=self.clear_cuda_cache
             )
 
             if init_color:
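A small aside on the mechanism (general Python, not specific to this commit): self.clear_cuda_cache is passed as a bound method, so the generator receives a zero-argument callable that still updates this Generate instance's counters:

class Counter:
    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1

c = Counter()
callback = c.bump    # a bound method carries its instance along
callback()
callback()
assert c.n == 2      # calls through the callback mutated the original object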
@@ -566,8 +568,7 @@ class Generate:
 
         except KeyboardInterrupt:
             # Clear the CUDA cache on an exception
-            if self._has_cuda():
-                torch.cuda.empty_cache()
+            self.clear_cuda_cache()
 
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
@@ -575,8 +576,7 @@ class Generate:
                 raise KeyboardInterrupt
         except RuntimeError:
             # Clear the CUDA cache on an exception
-            if self._has_cuda():
-                torch.cuda.empty_cache()
+            self.clear_cuda_cache()
 
             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')
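Both handlers previously duplicated the _has_cuda() guard and the bare torch.cuda.empty_cache() call; routing them through clear_cuda_cache() means an interrupted or failed run still records its VRAM peak before the cache is dropped. Schematically (a sketch, with sample() as a hypothetical stand-in for the sampling call):

def guarded_generate(gen):
    # gen is a Generate-like object; sample() stands in for the sampling loop
    try:
        return gen.sample()
    except (KeyboardInterrupt, RuntimeError):
        gen.clear_cuda_cache()   # shared cleanup: snapshot stats, then empty the cache
        raise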
@@ -587,22 +587,42 @@ class Generate:
             f'>> {len(results)} image(s) generated in', '%4.2fs' % (
                 toc - tic)
         )
+        self.print_cuda_stats()
+        return results
+
+    def clear_cuda_cache(self):
+        if self._has_cuda():
+            self.max_memory_allocated = max(
+                self.max_memory_allocated,
+                torch.cuda.max_memory_allocated()
+            )
+            self.memory_allocated = max(
+                self.memory_allocated,
+                torch.cuda.memory_allocated()
+            )
+            self.session_peakmem = max(
+                self.session_peakmem,
+                torch.cuda.max_memory_allocated()
+            )
+            torch.cuda.empty_cache()
+
+    def clear_cuda_stats(self):
+        self.max_memory_allocated = 0
+        self.memory_allocated = 0
+
+    def print_cuda_stats(self):
         if self._has_cuda():
             print(
                 '>> Max VRAM used for this generation:',
-                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
+                '%4.2fG.' % (self.max_memory_allocated / 1e9),
                 'Current VRAM utilization:',
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+                '%4.2fG' % (self.memory_allocated / 1e9),
             )
 
-            self.session_peakmem = max(
-                self.session_peakmem, torch.cuda.max_memory_allocated()
-            )
             print(
                 '>> Max VRAM used since script start: ',
                 '%4.2fG' % (self.session_peakmem / 1e9),
             )
-        return results
 
     # this needs to be generalized to all sorts of postprocessors, which should be wrapped
     # in a nice harmonized call signature. For now we have a bunch of if/elses!
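The max() bookkeeping above works because the two torch.cuda counters behave differently: memory_allocated() falls as tensors are freed, while max_memory_allocated() is a high-water mark that survives empty_cache() and only resets via torch.cuda.reset_peak_memory_stats(). A standalone demonstration (real PyTorch APIs; needs a CUDA device):

import torch

if torch.cuda.is_available():
    x = torch.empty(256 * 1024 * 1024, device='cuda')     # ~1 GiB of float32
    print('allocated now: %4.2fG' % (torch.cuda.memory_allocated() / 1e9))
    del x
    torch.cuda.empty_cache()        # returns cached blocks to the driver
    # memory_allocated() is now ~0, but the peak reading survives, which is
    # why session_peakmem can be max()ed from it across generations:
    print('peak so far:   %4.2fG' % (torch.cuda.max_memory_allocated() / 1e9))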
@@ -123,8 +123,9 @@ class Generator:
                 seed = self.new_seed()
 
                 # Free up memory from the last generation.
-                if self.model.device.type == 'cuda':
-                    torch.cuda.empty_cache()
+                clear_cuda_cache = kwargs['clear_cuda_cache'] or None
+                if clear_cuda_cache is not None:
+                    clear_cuda_cache()
 
         return results
 
@@ -66,8 +66,9 @@ class Txt2Img2Img(Generator):
             )
 
             # Free up memory from the last generation.
-            if self.model.device.type == 'cuda':
-                torch.cuda.empty_cache()
+            clear_cuda_cache = kwargs['clear_cuda_cache'] or None
+            if clear_cuda_cache is not None:
+                clear_cuda_cache()
 
             second_pass_noise = self.get_noise_like(resized_latents)
 
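One caveat shared by both generator hunks: kwargs['clear_cuda_cache'] or None raises KeyError when the keyword is absent, so it relies on every caller threading the argument through, as prompt2image() now does. If tolerance for callers that omit it were wanted, dict.get() expresses the same "callback or None" intent without the lookup hazard (a suggested variant, not what this commit ships):

clear_cuda_cache = kwargs.get('clear_cuda_cache')   # None when not supplied
if clear_cuda_cache is not None:
    clear_cuda_cache()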