[mm] clear the cache entry for a model that got an OOM during loading

Lincoln Stein 2024-04-10 10:26:38 -04:00 committed by psychedelicious
parent 7bc77ddb40
commit 579082ac10
3 changed files with 16 additions and 7 deletions


@@ -24,9 +24,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
     def invoke(self, context: InvocationContext) -> MaskOutput:
         mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
-        mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
-            True
-        )
+        mask[
+            :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
+        ] = True

         mask_tensor_name = context.tensors.save(mask)
         return MaskOutput(


@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
     See :class:`Migration` for an example.
     """

-    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        ...


 class MigrationError(RuntimeError):


@@ -271,7 +271,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
-        cache_entry.model.to(target_device)
+        try:
+            cache_entry.model.to(target_device)
+        except torch.cuda.OutOfMemoryError as e:  # blow away cache entry
+            self._delete_cache_entry(cache_entry)
+            raise e
+
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
@@ -389,8 +394,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1
-                del self._cache_stack[pos]
-                del self._cached_models[model_key]
+                self._delete_cache_entry(cache_entry)
                 del cache_entry
             else:
@@ -417,3 +421,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()

         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
+        self._cache_stack.remove(cache_entry.key)
+        del self._cached_models[cache_entry.key]
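
For context, a minimal sketch of the pattern this commit introduces: if moving a cached model onto the target device raises torch.cuda.OutOfMemoryError, the cache evicts that entry before re-raising, so its key stack and key-to-entry map never keep a half-loaded model. The SimpleModelCache, CacheEntry, put, and move_to names below are illustrative stand-ins and not part of the InvokeAI codebase; only the _delete_cache_entry helper mirrors the one added in the diff.

from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class CacheEntry:
    key: str
    model: torch.nn.Module


@dataclass
class SimpleModelCache:
    # Two bookkeeping structures, like the diff's _cache_stack / _cached_models.
    _cache_stack: List[str] = field(default_factory=list)
    _cached_models: Dict[str, CacheEntry] = field(default_factory=dict)

    def put(self, entry: CacheEntry) -> None:
        self._cache_stack.append(entry.key)
        self._cached_models[entry.key] = entry

    def _delete_cache_entry(self, entry: CacheEntry) -> None:
        # Remove the entry from both structures so they stay in sync.
        self._cache_stack.remove(entry.key)
        del self._cached_models[entry.key]

    def move_to(self, entry: CacheEntry, device: torch.device) -> None:
        try:
            entry.model.to(device)
        except torch.cuda.OutOfMemoryError:
            # Loading blew past available VRAM: evict the entry before
            # re-raising so later lookups don't find a half-loaded model.
            self._delete_cache_entry(entry)
            raise

The re-raise (the diff uses "raise e") still surfaces the OOM to the caller; the only behavioral change is that the failed entry no longer lingers in the cache, which the eviction path previously had to clean up by hand with paired del statements.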