Merge branch 'main' into 2.3.0rc5

Lincoln Stein 2023-02-06 12:55:47 -05:00 committed by GitHub
commit bde6e96800
3 changed files with 36 additions and 14 deletions

View File

@@ -344,6 +344,7 @@ class Generate:
            **args,
    ): # eat up additional cruft
        self.clear_cuda_stats()
        """
        ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
        It takes the following arguments:
@@ -548,6 +549,7 @@ class Generate:
                inpaint_width = inpaint_width,
                enable_image_debugging = enable_image_debugging,
                free_gpu_mem=self.free_gpu_mem,
                clear_cuda_cache=self.clear_cuda_cache
            )
            if init_color:
@@ -566,8 +568,7 @@ class Generate:
        except KeyboardInterrupt:
            # Clear the CUDA cache on an exception
            if self._has_cuda():
                torch.cuda.empty_cache()
            self.clear_cuda_cache()
            if catch_interrupts:
                print('**Interrupted** Partial results will be returned.')
@@ -575,8 +576,7 @@ class Generate:
                raise KeyboardInterrupt
        except RuntimeError:
            # Clear the CUDA cache on an exception
            if self._has_cuda():
                torch.cuda.empty_cache()
            self.clear_cuda_cache()
            print(traceback.format_exc(), file=sys.stderr)
            print('>> Could not generate image.')
@@ -587,22 +587,42 @@ class Generate:
            f'>> {len(results)} image(s) generated in', '%4.2fs' % (
                toc - tic)
        )
        self.print_cuda_stats()
        return results

    def clear_cuda_cache(self):
        if self._has_cuda():
            self.max_memory_allocated = max(
                self.max_memory_allocated,
                torch.cuda.max_memory_allocated()
            )
            self.memory_allocated = max(
                self.memory_allocated,
                torch.cuda.memory_allocated()
            )
            self.session_peakmem = max(
                self.session_peakmem,
                torch.cuda.max_memory_allocated()
            )
            torch.cuda.empty_cache()

    def clear_cuda_stats(self):
        self.max_memory_allocated = 0
        self.memory_allocated = 0

    def print_cuda_stats(self):
        if self._has_cuda():
            print(
                '>> Max VRAM used for this generation:',
                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
                '%4.2fG.' % (self.max_memory_allocated / 1e9),
                'Current VRAM utilization:',
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
                '%4.2fG' % (self.memory_allocated / 1e9),
            )
            self.session_peakmem = max(
                self.session_peakmem, torch.cuda.max_memory_allocated()
            )
            print(
                '>> Max VRAM used since script start: ',
                '%4.2fG' % (self.session_peakmem / 1e9),
            )
        return results

    # this needs to be generalized to all sorts of postprocessors, which should be wrapped
    # in a nice harmonized call signature. For now we have a bunch of if/elses!
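
Taken together, the three new helpers split VRAM bookkeeping into distinct phases: clear_cuda_stats() zeroes the per-generation counters at the top of prompt2image(), clear_cuda_cache() folds the current torch.cuda peak readings into those counters (and into session_peakmem) before emptying the cache, and print_cuda_stats() reports the accumulated figures once results are in. A minimal, self-contained sketch of the same pattern, assuming only a CUDA-capable PyTorch install (the class and method names below are illustrative, not part of the commit):

import torch

class VRAMTracker:
    """Illustrative stand-in for the bookkeeping methods added to Generate."""

    def __init__(self):
        self.session_peakmem = 0
        self.clear_stats()

    def clear_stats(self):
        # Reset the per-generation counters (mirrors clear_cuda_stats).
        self.max_memory_allocated = 0
        self.memory_allocated = 0

    def clear_cache(self):
        # Fold the current peak readings into the counters *before* emptying
        # the cache, so clearing VRAM mid-generation does not discard the
        # figures we want to report later (mirrors clear_cuda_cache).
        if torch.cuda.is_available():
            self.max_memory_allocated = max(self.max_memory_allocated, torch.cuda.max_memory_allocated())
            self.memory_allocated = max(self.memory_allocated, torch.cuda.memory_allocated())
            self.session_peakmem = max(self.session_peakmem, torch.cuda.max_memory_allocated())
            torch.cuda.empty_cache()

    def print_stats(self):
        # Report the per-generation and per-session peaks (mirrors print_cuda_stats).
        if torch.cuda.is_available():
            print('>> Max VRAM used for this generation: %4.2fG' % (self.max_memory_allocated / 1e9))
            print('>> Max VRAM used since script start:  %4.2fG' % (self.session_peakmem / 1e9))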

View File

@@ -123,8 +123,9 @@ class Generator:
            seed = self.new_seed()

            # Free up memory from the last generation.
            if self.model.device.type == 'cuda':
                torch.cuda.empty_cache()
                clear_cuda_cache = kwargs['clear_cuda_cache'] or None
                if clear_cuda_cache is not None:
                    clear_cuda_cache()

        return results
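
One detail worth noting in the generator-side change: kwargs['clear_cuda_cache'] or None raises a KeyError whenever a caller omits the keyword, whereas kwargs.get('clear_cuda_cache') quietly returns None. A hedged sketch of the more defensive lookup follows; the fallback to torch.cuda.empty_cache() is an assumption about a sensible default, not something this commit does:

import torch

def free_gpu_memory(model, **kwargs):
    # Free VRAM between images: prefer a caller-supplied callback, otherwise
    # (assumption) fall back to emptying the CUDA cache directly.
    if model.device.type == 'cuda':
        clear_cuda_cache = kwargs.get('clear_cuda_cache')
        if clear_cuda_cache is not None:
            clear_cuda_cache()
        else:
            torch.cuda.empty_cache()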

View File

@@ -66,8 +66,9 @@ class Txt2Img2Img(Generator):
            )

            # Free up memory from the last generation.
            if self.model.device.type == 'cuda':
                torch.cuda.empty_cache()
                clear_cuda_cache = kwargs['clear_cuda_cache'] or None
                if clear_cuda_cache is not None:
                    clear_cuda_cache()

            second_pass_noise = self.get_noise_like(resized_latents)
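
For context, the callback the generators receive here is the bound Generate.clear_cuda_cache method passed down from prompt2image() in the first file above. A condensed, hypothetical illustration of that wiring, reusing the VRAMTracker sketch from earlier; generator.generate() stands in for the real invocation chain and is not the exact InvokeAI signature:

def run_one_prompt(generator, prompt, tracker, **extra):
    # Reset the per-generation counters, hand the cache-clearing callback to
    # the generator, and report the accumulated peaks when done.
    tracker.clear_stats()
    try:
        results = generator.generate(prompt, clear_cuda_cache=tracker.clear_cache, **extra)
    finally:
        tracker.clear_cache()
        tracker.print_stats()
    return results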