From d22738723da5ebd065b78fff2dabc22064566f83 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 24 Feb 2024 11:25:40 -0500
Subject: [PATCH] clear out VRAM when an OOM occurs

---
 .../model_manager/load/model_cache/model_cache_default.py   | 6 ++++++
 .../backend/model_manager/load/model_cache/model_locker.py  | 6 ++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 02ce1266c7..094b4f958c 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -294,6 +294,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

+    def _clear_vram(self) -> None:
+        """Called on out of memory errors. Moves all our models out of VRAM."""
+        self.logger.warning('Resetting VRAM cache.')
+        for model in self._cached_models.values():
+            self.move_model_to_device(model, torch.device('cpu'))
+
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index 7a5fdd4284..20c598e1cf 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -2,8 +2,8 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+import torch
 from invokeai.backend.model_manager import AnyModel
-
 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

@@ -42,7 +42,9 @@ class ModelLocker(ModelLockerBase):

             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
             self._cache.print_cuda_stats()
-
+        except torch.cuda.OutOfMemoryError:
+            self._cache._clear_vram()
+            raise
         except Exception:
             self._cache_entry.unlock()
             raise