Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
recover gracefully from VRAM out of memory errors
parent 9986fce1a6
commit b3abc7252d
@@ -287,6 +287,14 @@ class ModelCache(object):
         if torch.device(source_device).type == torch.device(target_device).type:
             return
 
+        if target_device.type == "cuda":
+            vram_device = (
+                target_device if target_device.index is not None else torch.device(str(target_device), index=0)
+            )
+            free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
+            if cache_entry.size > free_mem:
+                raise torch.cuda.OutOfMemoryError
+
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
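The block added above is a pre-flight check: before moving a model onto a CUDA device, it asks the driver how much VRAM is actually free and raises torch.cuda.OutOfMemoryError up front if the model cannot fit, instead of failing partway through the transfer. Below is a minimal, self-contained sketch of the same pattern; assert_enough_vram, model_size_bytes, and device are illustrative names, not identifiers from this commit.

import torch

def assert_enough_vram(model_size_bytes: int, device: torch.device) -> None:
    """Raise before attempting a transfer that cannot fit in free VRAM.

    Hypothetical helper, not part of the InvokeAI diff above.
    """
    if device.type != "cuda":
        return  # mem_get_info only applies to CUDA devices
    # mem_get_info needs a concrete device index; default to 0, as the diff does.
    query_device = device if device.index is not None else torch.device("cuda", index=0)
    free_mem, _total = torch.cuda.mem_get_info(query_device)
    if model_size_bytes > free_mem:
        # Raising torch.cuda.OutOfMemoryError lets existing OOM handlers treat a
        # predicted failure the same way as a real allocation failure.
        raise torch.cuda.OutOfMemoryError(
            f"Model needs {model_size_bytes / 2**30:.2f}GB but only "
            f"{free_mem / 2**30:.2f}GB VRAM is free on {query_device}"
        )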
@@ -356,6 +364,10 @@ class ModelCache(object):
                     self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
                     self.cache._print_cuda_stats()
 
+                except torch.cuda.OutOfMemoryError:
+                    self.cache.logger.warning("Out of GPU memory encountered.")
+                    self.cache_entry.unlock()
+                    raise
                 except Exception:
                     self.cache_entry.unlock()
                     raise
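This hunk adds a dedicated handler for torch.cuda.OutOfMemoryError in the model locker: it logs a warning, unlocks the cache entry so it can be evicted, and re-raises so the caller decides how to recover. A hedged sketch of what caller-side recovery could look like follows; run_with_oom_fallback and its arguments are hypothetical and not part of InvokeAI.

import torch

def run_with_oom_fallback(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Illustrative caller-side recovery; none of these names come from the commit."""
    try:
        return model.to("cuda")(batch.to("cuda"))
    except torch.cuda.OutOfMemoryError:
        # The cache re-raises after unlocking, so a caller can release cached
        # CUDA blocks and fall back to CPU instead of crashing the whole request.
        torch.cuda.empty_cache()
        return model.to("cpu")(batch.to("cpu"))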
@@ -524,7 +536,6 @@ class ModelCache(object):
                 break
             if not cache_entry.locked and cache_entry.loaded:
                 self._move_model_to_device(model_key, self.storage_device)
-
                 vram_in_use = torch.cuda.memory_allocated()
                 self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
 
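The last hunk only tightens whitespace in the offload loop, which walks the cached models, moves unlocked ones back to the storage device, and logs VRAM usage against the configured budget via torch.cuda.memory_allocated(). A rough standalone sketch of that loop, assuming a plain dict of modules in place of the real cache entries:

import torch

GIG = 2**30  # same constant the log line divides by

def offload_until_under_budget(models: dict[str, torch.nn.Module], max_vram_gb: float) -> None:
    """Sketch only: `models` stands in for the real cache entries, with no lock tracking."""
    reserved = max_vram_gb * GIG
    vram_in_use = torch.cuda.memory_allocated()
    for key, model in models.items():
        if vram_in_use <= reserved:
            break
        model.to("cpu")  # move the model back to the storage device, freeing its VRAM
        vram_in_use = torch.cuda.memory_allocated()
        print(f"{vram_in_use / GIG:.2f}GB VRAM used for models; max allowed={reserved / GIG:.2f}GB")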