From 24637104978d7b3ad6018ad944971b9b32b87233 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 24 Feb 2024 12:32:30 -0500 Subject: [PATCH] recover gracefuly from GPU out of memory errors (next version) --- .../load/model_cache/model_cache_default.py | 27 ++++++++++++++----- .../load/model_cache/model_locker.py | 3 ++- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 094b4f958c..e1c5e743c1 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -245,7 +245,13 @@ class ModelCache(ModelCacheBase[AnyModel]): mps.empty_cache() def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: - """Move model into the indicated device.""" + """Move model into the indicated device. + + :param cache_entry: The CacheRecord for the model + :param target_device: The torch.device to move the model into + + May raise a torch.cuda.OutOfMemoryError + """ # These attributes are not in the base ModelMixin class but in various derived classes. # Some models don't have these attributes, in which case they run in RAM/CPU. self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") @@ -259,6 +265,9 @@ class ModelCache(ModelCacheBase[AnyModel]): if torch.device(source_device).type == torch.device(target_device).type: return + # may raise an exception here if insufficient GPU VRAM + self._check_free_vram(target_device, cache_entry.size) + start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() cache_entry.model.to(target_device) @@ -294,12 +303,6 @@ class ModelCache(ModelCacheBase[AnyModel]): f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) - def _clear_vram(self) -> None: - """Called on out of memory errors. Moves all our models out of VRAM.""" - self.logger.warning('Resetting VRAM cache.') - for model in self._cached_models.values(): - self.move_model_to_device(model, torch.device('cpu')) - def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) @@ -411,3 +414,13 @@ class ModelCache(ModelCacheBase[AnyModel]): mps.empty_cache() self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") + + def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None: + if target_device.type != "cuda": + return + vram_device = ( # mem_get_info() needs an indexed device + target_device if target_device.index is not None else torch.device(str(target_device), index=0) + ) + free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device)) + if needed_size > free_mem: + raise torch.cuda.OutOfMemoryError diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 20c598e1cf..3651590cec 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -43,7 +43,8 @@ class ModelLocker(ModelLockerBase): self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") self._cache.print_cuda_stats() except torch.cuda.OutOfMemoryError: - self._cache._clear_vram() + self._cache.logger.warning("Insufficient GPU memory to load model. Aborting") + self._cache_entry.unlock() raise except Exception: self._cache_entry.unlock()