mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
In exception handlers, clear the torch CUDA cache (if we're using CUDA) to free up memory for other programs using the GPU and to reduce fragmentation. (#2549)
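The same pattern is applied in three places below: when generation is interrupted or fails, PyTorch's cached GPU allocations are released so other processes can use the VRAM and the allocator restarts from a less fragmented state. A minimal self-contained sketch of the pattern (the run_generation callable and device argument are illustrative stand-ins, not names from this commit):

import torch

def generate_safely(run_generation, device):
    # Run one generation step, releasing cached CUDA memory on failure.
    try:
        return run_generation()
    except (KeyboardInterrupt, RuntimeError):
        # empty_cache() returns unused cached blocks to the driver; it is
        # only meaningful on CUDA devices, hence the device check.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        raise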
@@ -1208,12 +1208,18 @@ class InvokeAIWebServer:
             )

         except KeyboardInterrupt:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             self.socketio.emit("processingCanceled")
             raise
         except CanceledException:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             self.socketio.emit("processingCanceled")
             pass
         except Exception as e:
+            # Clear the CUDA cache on an exception
+            self.empty_cuda_cache()
             print(e)
             self.socketio.emit("error", {"message": (str(e))})
             print("\n")
@@ -1221,6 +1227,12 @@ class InvokeAIWebServer:
             traceback.print_exc()
             print("\n")

+    def empty_cuda_cache(self):
+        if self.generate.device.type == "cuda":
+            import torch.cuda
+
+            torch.cuda.empty_cache()
+
     def parameters_to_generated_image_metadata(self, parameters):
         try:
             # top-level metadata minus `image` or `images`
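Because empty_cuda_cache() checks the configured device itself, call sites need no device check of their own and the method is a harmless no-op on CPU or MPS. A roughly equivalent standalone helper, sketched with torch.cuda.is_available() as the guard instead of the server's device setting (an assumption, not code from this commit):

import torch

def empty_cuda_cache() -> None:
    # is_available() is False on CPU-only builds, so this is always safe to call.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()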
@@ -565,11 +565,19 @@ class Generate:
                               image_callback = image_callback)

         except KeyboardInterrupt:
+            # Clear the CUDA cache on an exception
+            if self._has_cuda():
+                torch.cuda.empty_cache()
+
             if catch_interrupts:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
         except RuntimeError:
+            # Clear the CUDA cache on an exception
+            if self._has_cuda():
+                torch.cuda.empty_cache()
+
             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')

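The RuntimeError branch matters because CUDA out-of-memory errors surface as RuntimeError in PyTorch; clearing the cache there gives a subsequent, smaller attempt a better chance of fitting. A hedged sketch of that retry idea (the generate callable and halving policy are assumptions, not InvokeAI behavior):

import torch

def generate_with_oom_retry(generate, batch_size):
    # On CUDA OOM, clear the cache and retry with half the batch.
    try:
        return generate(batch_size)
    except RuntimeError as e:
        if 'out of memory' not in str(e) or batch_size <= 1:
            raise
        torch.cuda.empty_cache()  # drop cached blocks before retrying
        return generate_with_oom_retry(generate, batch_size // 2)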
@@ -65,6 +65,10 @@ class Txt2Img2Img(Generator):
             mode="bilinear"
         )

+        # Free up memory from the last generation.
+        if self.model.device.type == 'cuda':
+            torch.cuda.empty_cache()
+
         second_pass_noise = self.get_noise_like(resized_latents)

         verbosity = get_verbosity()
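The effect of these calls is visible in PyTorch's own accounting: empty_cache() lowers the memory reserved by the caching allocator (torch.cuda.memory_reserved) without touching live tensors (torch.cuda.memory_allocated). A small sketch to observe this on a CUDA machine:

import torch

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device='cuda')
    del x  # the tensor dies, but its block stays in PyTorch's cache
    print('reserved before:', torch.cuda.memory_reserved())
    torch.cuda.empty_cache()  # hand cached, unused blocks back to the driver
    print('reserved after: ', torch.cuda.memory_reserved())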