catch RunTimeError during model to() call rather than OutOfMemoryError

This commit is contained in:
Lincoln Stein 2024-04-10 23:44:14 -04:00 committed by psychedelicious
parent dedf0c6ffa
commit 46d23cd868

View File

@ -273,7 +273,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
snapshot_before = self._capture_memory_snapshot()
try:
cache_entry.model.to(target_device)
except torch.cuda.OutOfMemoryError as e: # blow away cache entry
except RuntimeError as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e