recover gracefully from VRAM out of memory errors

Lincoln Stein 2024-02-24 12:10:52 -05:00
parent 9986fce1a6
commit b3abc7252d


@@ -287,6 +287,14 @@ class ModelCache(object):
if torch.device(source_device).type == torch.device(target_device).type:
    return

if target_device.type == "cuda":
    vram_device = (
        target_device if target_device.index is not None else torch.device(str(target_device), index=0)
    )
    free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
    if cache_entry.size > free_mem:
        raise torch.cuda.OutOfMemoryError

start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
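
The hunk above pre-checks free VRAM before moving a model onto the GPU, so the cache raises torch.cuda.OutOfMemoryError up front instead of failing partway through the copy. A minimal standalone sketch of the same check is below; move_if_it_fits and model_size_bytes are hypothetical stand-ins for the cache entry's recorded size and the actual .to() call, not names from this codebase.

import torch

def move_if_it_fits(model: torch.nn.Module, model_size_bytes: int, device: torch.device) -> None:
    # Sketch: refuse to move a model onto a CUDA device that lacks enough free VRAM for it.
    if device.type == "cuda":
        # Normalize a bare "cuda" device to an explicit index so mem_get_info accepts it.
        query_device = device if device.index is not None else torch.device("cuda", index=0)
        free_mem, _total = torch.cuda.mem_get_info(query_device)
        if model_size_bytes > free_mem:
            # Fail fast; the caller can catch this and fall back (e.g. keep running on the CPU).
            raise torch.cuda.OutOfMemoryError(
                f"model needs {model_size_bytes} bytes but only {free_mem} bytes of VRAM are free"
            )
    model.to(device)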
@@ -356,6 +364,10 @@ class ModelCache(object):
    self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
    self.cache._print_cuda_stats()
except torch.cuda.OutOfMemoryError:
    self.cache.logger.warning("Out of GPU memory encountered.")
    self.cache_entry.unlock()
    raise
except Exception:
    self.cache_entry.unlock()
    raise
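
In this hunk the lock path catches torch.cuda.OutOfMemoryError separately so the cache entry is unlocked, and a warning is logged, before the error propagates to the caller. A hedged sketch of that unlock-then-reraise pattern follows; the LockedModel class and its lock()/unlock() methods are an illustrative stand-in, not the real ModelLocker API.

import logging
import torch

logger = logging.getLogger(__name__)

class LockedModel:
    # Sketch of a context-manager style lock that never leaks a lock when the GPU runs out of memory.
    def __init__(self, cache_entry, execution_device: torch.device):
        self.cache_entry = cache_entry
        self.execution_device = execution_device

    def __enter__(self):
        self.cache_entry.lock()
        try:
            self.cache_entry.model.to(self.execution_device)
        except torch.cuda.OutOfMemoryError:
            # Log, release the lock, then let the caller decide how to recover.
            logger.warning("Out of GPU memory encountered.")
            self.cache_entry.unlock()
            raise
        except Exception:
            self.cache_entry.unlock()
            raise
        return self.cache_entry.model

    def __exit__(self, exc_type, exc, tb):
        self.cache_entry.unlock()
        return False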
@@ -524,7 +536,6 @@ class ModelCache(object):
        break
    if not cache_entry.locked and cache_entry.loaded:
        self._move_model_to_device(model_key, self.storage_device)
        vram_in_use = torch.cuda.memory_allocated()
        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
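
The last hunk touches the offload loop that moves unlocked models back to the storage device when allocated VRAM exceeds the configured budget. A simplified sketch of that eviction loop is below; it assumes a dict of entries with size, locked, loaded, and model attributes, and the function name and parameters are illustrative rather than the real ModelCache fields.

import logging
import torch

GIG = 2**30
logger = logging.getLogger(__name__)

def offload_until_under_budget(cached_models: dict, max_vram_gb: float, storage_device: torch.device) -> None:
    # Evict the smallest unlocked, loaded entries first until allocated VRAM fits the budget.
    reserved = max_vram_gb * GIG
    for _key, entry in sorted(cached_models.items(), key=lambda kv: kv[1].size):
        vram_in_use = torch.cuda.memory_allocated()
        if vram_in_use <= reserved:
            break
        if not entry.locked and entry.loaded:
            entry.model.to(storage_device)
            entry.loaded = False
            vram_in_use = torch.cuda.memory_allocated()
            logger.debug(f"{vram_in_use/GIG:.2f}GB VRAM used for models; max allowed={reserved/GIG:.2f}GB")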