clear out VRAM when an OOM occurs

Lincoln Stein 2024-02-24 11:25:40 -05:00 committed by psychedelicious
parent fbd9ffdc5a
commit d22738723d
2 changed files with 10 additions and 2 deletions


@@ -294,6 +294,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )
 
+    def _clear_vram(self) -> None:
+        """Called on out of memory errors. Moves all our models out of VRAM."""
+        self.logger.warning('Resetting VRAM cache.')
+        for model in self._cached_models.values():
+            self.move_model_to_device(model, torch.device('cpu'))
+
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
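For readers unfamiliar with the mechanics: moving a module's parameters back to host memory is what actually releases its VRAM, which is all _clear_vram relies on. A minimal, self-contained PyTorch sketch of the idea (nothing below is InvokeAI code; the model and sizes are illustrative):

import torch

# Allocate a model on the GPU; its parameters occupy VRAM.
model = torch.nn.Linear(4096, 4096).to("cuda")
print(torch.cuda.memory_allocated())  # nonzero: weight + bias live on the GPU

# Moving the module to CPU frees its CUDA storages back to PyTorch's
# caching allocator; empty_cache() optionally returns the cached blocks
# to the driver so other processes can use them.
model.to("cpu")
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # back near zero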


@@ -2,8 +2,8 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """
 
+import torch
+
 from invokeai.backend.model_manager import AnyModel
 
 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
@@ -42,7 +42,9 @@ class ModelLocker(ModelLockerBase):
             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
             self._cache.print_cuda_stats()
+        except torch.cuda.OutOfMemoryError:
+            self._cache._clear_vram()
+            raise
         except Exception:
             self._cache_entry.unlock()
             raise
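Two details of the new handler are worth noting. torch.cuda.OutOfMemoryError is a subclass of RuntimeError (and therefore of Exception), so the specific clause must come before the existing except Exception or it would never run; and the bare raise re-raises the OOM after the cache is flushed, so the caller still observes the failure but a retry starts against empty VRAM. A small self-contained illustration of the ordering (simulated error; not InvokeAI code):

import torch

def locked_operation(simulate_oom: bool = True) -> None:
    try:
        if simulate_oom:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory (simulated)")
    except torch.cuda.OutOfMemoryError:
        # Specific handler first: flush the cache, then propagate the error.
        print("clearing VRAM cache")
        raise
    except Exception:
        # Generic cleanup; this clause would shadow the OOM if listed first.
        print("unlocking cache entry")
        raise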