mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
recover gracefully from GPU out of memory errors (next version)
This commit is contained in:
parent d22738723d
commit 371e3cc260
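
For orientation: the commit replaces the old "move everything back to CPU on OOM" recovery with a pre-flight free-memory check before a model is copied onto the GPU. Below is a minimal standalone sketch of that check, assuming a CUDA build of torch; the name check_free_vram and the 256 MiB size are illustrative only, while the helper actually added by this commit is ModelCache._check_free_vram (see the last hunk of the first file).

import torch

def check_free_vram(target_device: torch.device, needed_size: int) -> None:
    # Raise before attempting a model.to(device) that cannot possibly fit.
    if target_device.type != "cuda":
        return  # only CUDA devices are checked
    # mem_get_info() needs an indexed device, e.g. "cuda:0" rather than bare "cuda".
    device = target_device if target_device.index is not None else torch.device("cuda", index=0)
    free_mem, _total = torch.cuda.mem_get_info(device)
    if needed_size > free_mem:
        raise torch.cuda.OutOfMemoryError

if torch.cuda.is_available():
    # Hypothetical 256 MiB payload, purely for demonstration.
    check_free_vram(torch.device("cuda"), 256 * 1024 * 1024)
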
@@ -245,7 +245,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()
 
     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
+        """Move model into the indicated device.
+
+        :param cache_entry: The CacheRecord for the model
+        :param target_device: The torch.device to move the model into
+
+        May raise a torch.cuda.OutOfMemoryError
+        """
         # These attributes are not in the base ModelMixin class but in various derived classes.
         # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
@@ -259,6 +265,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if torch.device(source_device).type == torch.device(target_device).type:
             return
 
+        # may raise an exception here if insufficient GPU VRAM
+        self._check_free_vram(target_device, cache_entry.size)
+
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
@@ -294,12 +303,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )
 
-    def _clear_vram(self) -> None:
-        """Called on out of memory errors. Moves all our models out of VRAM."""
-        self.logger.warning('Resetting VRAM cache.')
-        for model in self._cached_models.values():
-            self.move_model_to_device(model, torch.device('cpu'))
-
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@@ -411,3 +414,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()
 
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
+        if target_device.type != "cuda":
+            return
+        vram_device = (  # mem_get_info() needs an indexed device
+            target_device if target_device.index is not None else torch.device(str(target_device), index=0)
+        )
+        free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
+        if needed_size > free_mem:
+            raise torch.cuda.OutOfMemoryError
@@ -43,7 +43,8 @@ class ModelLocker(ModelLockerBase):
             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
-            self._cache._clear_vram()
+            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
+            self._cache_entry.unlock()
             raise
         except Exception:
             self._cache_entry.unlock()
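
Since ModelLocker.lock() now logs a warning, unlocks the cache entry, and re-raises instead of clearing VRAM itself, callers are expected to catch the error. A hypothetical caller-side sketch follows; the _FakeLocker stand-in is illustrative only and not part of the commit.

import torch

class _FakeLocker:
    """Hypothetical stand-in for ModelLocker, used only to make the sketch runnable."""
    def lock(self):
        # Simulate the path where the pre-flight check (or the .to() call) fails.
        raise torch.cuda.OutOfMemoryError

locker = _FakeLocker()
try:
    model = locker.lock()  # the real locker warns, unlocks the entry, then re-raises
except torch.cuda.OutOfMemoryError:
    # Recover gracefully: report the failure instead of crashing the whole process.
    print("Insufficient GPU memory to load model. Aborting")
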