[mm] clear the cache entry for a model that got an OOM during loading

Authored by Lincoln Stein on 2024-04-10 10:26:38 -04:00
Committed by psychedelicious
Parent: 7bc77ddb40
Commit: 579082ac10
3 changed files with 16 additions and 7 deletions


@@ -271,7 +271,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
-        cache_entry.model.to(target_device)
+        try:
+            cache_entry.model.to(target_device)
+        except torch.cuda.OutOfMemoryError as e:  # blow away cache entry
+            self._delete_cache_entry(cache_entry)
+            raise e
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
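
The hunk above wraps the device move in a try/except so that a CUDA out-of-memory error evicts the failed entry before the exception propagates. Below is a minimal, self-contained sketch of that pattern; ToyCache and ToyCacheRecord are hypothetical stand-ins for illustration, not the real ModelCache API, and the example assumes a PyTorch version that exposes torch.cuda.OutOfMemoryError.

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class ToyCacheRecord:
    # Hypothetical stand-in for the real CacheRecord: just a key and a model.
    key: str
    model: torch.nn.Module


class ToyCache:
    def __init__(self) -> None:
        self._cache_stack: List[str] = []                    # keys, oldest first
        self._cached_models: Dict[str, ToyCacheRecord] = {}  # key -> record

    def _delete_cache_entry(self, record: ToyCacheRecord) -> None:
        # Drop the key from both bookkeeping structures so neither one keeps
        # pointing at an entry the other no longer tracks.
        self._cache_stack.remove(record.key)
        del self._cached_models[record.key]

    def move_to_device(self, record: ToyCacheRecord, device: torch.device) -> None:
        try:
            record.model.to(device)
        except torch.cuda.OutOfMemoryError:
            # The move failed part-way; evict the record so the cache does not
            # keep serving a half-moved model, then re-raise for the caller.
            self._delete_cache_entry(record)
            raise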
@@ -389,8 +394,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1
-                del self._cache_stack[pos]
-                del self._cached_models[model_key]
+                self._delete_cache_entry(cache_entry)
                 del cache_entry
             else:
@@ -417,3 +421,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()

         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
+        self._cache_stack.remove(cache_entry.key)
+        del self._cached_models[cache_entry.key]
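
The new _delete_cache_entry helper centralizes the two-step removal that the second hunk previously spelled out inline. As a rough illustration of why both containers have to be updated together, here is a tiny sketch using plain Python containers; the keys and the delete_cache_entry name are placeholders, not the project's code.

from typing import Dict, List

cache_stack: List[str] = ["vae", "unet", "text_encoder"]          # placeholder keys
cached_models: Dict[str, object] = {key: object() for key in cache_stack}


def delete_cache_entry(key: str) -> None:
    # Mirror of the helper: remove the key from the ordering list and the
    # lookup dict in one step so the two structures stay consistent.
    cache_stack.remove(key)
    del cached_models[key]


delete_cache_entry("unet")
assert "unet" not in cache_stack and "unet" not in cached_models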