In exception handlers, clear the torch CUDA cache (if we're using CUDA) to free up memory for other programs using the GPU and to reduce fragmentation. (#2549)

This commit is contained in:
Jonathan 2023-02-06 09:33:24 -06:00 committed by GitHub
parent a485d45400
commit 2432adb38f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 1 deletion

View File

@@ -1208,12 +1208,18 @@ class InvokeAIWebServer:
)
except KeyboardInterrupt:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
self.socketio.emit("processingCanceled")
raise
except CanceledException:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
self.socketio.emit("processingCanceled")
pass
except Exception as e:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
print(e)
self.socketio.emit("error", {"message": (str(e))})
print("\n")
@@ -1221,6 +1227,12 @@ class InvokeAIWebServer:
traceback.print_exc()
print("\n")
def empty_cuda_cache(self):
    """Release cached GPU memory back to the driver.

    Only acts when the active generation device is CUDA; on any other
    device type (cpu, mps, ...) this is a no-op. Called from exception
    handlers so other GPU programs can reclaim memory and fragmentation
    is reduced after a failed or cancelled generation.
    """
    # Guard clause: nothing to do unless we are actually running on CUDA.
    if self.generate.device.type != "cuda":
        return
    # Imported lazily so non-CUDA setups never touch torch.cuda here.
    import torch.cuda
    torch.cuda.empty_cache()
def parameters_to_generated_image_metadata(self, parameters):
try:
# top-level metadata minus `image` or `images`

View File

@@ -211,7 +211,7 @@ class Generate:
print('>> xformers memory-efficient attention is available but disabled')
else:
print('>> xformers not installed')
# model caching system for fast switching
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
# don't accept invalid models
@@ -565,11 +565,19 @@ class Generate:
image_callback = image_callback)
except KeyboardInterrupt:
# Clear the CUDA cache on an exception
if self._has_cuda():
torch.cuda.empty_cache()
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
except RuntimeError:
# Clear the CUDA cache on an exception
if self._has_cuda():
torch.cuda.empty_cache()
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')

View File

@@ -65,6 +65,10 @@ class Txt2Img2Img(Generator):
mode="bilinear"
)
# Free up memory from the last generation.
if self.model.device.type == 'cuda':
torch.cuda.empty_cache()
second_pass_noise = self.get_noise_like(resized_latents)
verbosity = get_verbosity()