Offload the current model when locking if it is already partially loaded and we have insufficient VRAM.

Ryan Dick
2025-01-07 02:53:44 +00:00
parent 5eafe1ec7a
commit d7ab464176
3 changed files with 59 additions and 7 deletions
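
In outline, the change adds a last-resort step to the model-locking path: after unlocked models have been offloaded, if there is still not enough VRAM, partially unload the model being locked itself (it may still be resident from a previous use), keeping any weights that must stay in VRAM. The sketch below is a simplified, hypothetical rendering of that control flow; the callables stand in for the ModelCache methods touched in this diff and are not part of the actual API.

from typing import Callable


def make_room_for_locked_model(
    vram_needed_bytes: int,
    offload_unlocked_models: Callable[[int], int],  # stand-in for ModelCache._offload_unlocked_models
    unload_locked_model: Callable[[int], int],      # stand-in for _move_model_to_ram on the entry being locked
    get_vram_available: Callable[[], int],          # stand-in for ModelCache._get_vram_available
) -> int:
    """Hypothetical sketch of the lock-time VRAM bookkeeping described by this commit."""
    # Step 1: offload other (unlocked) models to make room.
    offload_unlocked_models(vram_needed_bytes)
    vram_available = get_vram_available()

    # Step 2 (new in this commit): if there is still a shortfall, unload weights belonging
    # to the model being locked itself, which may linger in VRAM from a previous,
    # partially loaded use.
    if vram_available < 0:
        unload_locked_model(-vram_available)
        vram_available = get_vram_available()

    return vram_available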


@@ -166,13 +166,17 @@ class CachedModelWithPartialLoad:
         return vram_bytes_loaded

     @torch.no_grad()
-    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
+    def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weights_in_vram: bool = False) -> int:
         """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.

+        :param keep_required_weights_in_vram: If True, any weights that must be kept in VRAM to run the model will be
+            kept in VRAM.
+
         Returns:
             The number of bytes unloaded from VRAM.
         """
         vram_bytes_freed = 0
+        required_weights_in_vram = 0

         offload_device = "cpu"
         cur_state_dict = self._model.state_dict()
@@ -183,6 +187,10 @@ class CachedModelWithPartialLoad:
             if param.device.type == offload_device:
                 continue

+            if keep_required_weights_in_vram and key in self._keys_in_modules_that_do_not_support_autocast:
+                required_weights_in_vram += self._state_dict_bytes[key]
+                continue
+
             cur_state_dict[key] = self._cpu_state_dict[key]
             vram_bytes_freed += self._state_dict_bytes[key]
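
For reference, the new keyword argument can be exercised much like the test added at the bottom of this commit. The snippet below is an illustrative sketch only: the import path and the toy model are assumptions, while the method names (full_load_to_vram, partial_unload_from_vram, cur_vram_bytes, total_bytes) come from this diff and the tests.

import torch

# Assumed import path; adjust to wherever CachedModelWithPartialLoad lives in the codebase.
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)

# Any torch.nn.Module works as a stand-in here; the real tests use a small DummyModule.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))
cached = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))

cached.full_load_to_vram()

# Ask for everything to be freed, but keep weights that must remain in VRAM
# (keys in modules that do not support device autocasting).
freed_bytes = cached.partial_unload_from_vram(
    cached.total_bytes(), keep_required_weights_in_vram=True
)
print(f"Freed {freed_bytes} bytes; {cached.cur_vram_bytes()} bytes still resident in VRAM.")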


@@ -269,7 +269,7 @@ class ModelCache:
         # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
         # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
         #    possible.
-        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
+        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
         self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")

         # Check the updated vram_available after offloading.
@@ -278,6 +278,15 @@ class ModelCache:
             f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
         )

+        if vram_available < 0:
+            # There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
+            # as it may still be loaded from a previous use.
+            vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
+            vram_available = self._get_vram_available(working_mem_bytes)
+            self._logger.debug(
+                f"Unloaded {vram_bytes_freed_from_own_model/MB:.2f}MB from the model being locked ({cache_entry.key})."
+            )
+
         # Move as much of the model as possible into VRAM.
         # For testing, only allow 10% of the model to be loaded into VRAM.
         # vram_available = int(model_vram_needed * 0.1)
@@ -318,7 +327,9 @@ class ModelCache:
     def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
         try:
             if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
-                return cache_entry.cached_model.partial_unload_from_vram(vram_bytes_to_free)
+                return cache_entry.cached_model.partial_unload_from_vram(
+                    vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
+                )
             elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad):  # type: ignore
                 return cache_entry.cached_model.full_unload_from_vram()
             else:
@@ -328,7 +339,7 @@ class ModelCache:
             self._delete_cache_entry(cache_entry)
             raise

-    def _get_vram_available(self, working_mem_bytes: Optional[int] = None) -> int:
+    def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
         """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
         memory).
         """
@@ -421,7 +432,7 @@ class ModelCache:
             + f"vram_available={(vram_available/MB):.0f} MB, "
         )

-    def _offload_unlocked_models(self, vram_bytes_required: int) -> int:
+    def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
         """Offload models from the execution_device until vram_bytes_required bytes are available, or all models are
         offloaded. Of course, locked models are not offloaded.
@@ -436,11 +447,13 @@ class ModelCache:
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
-            vram_available = self._get_vram_available()
+            vram_available = self._get_vram_available(working_mem_bytes)
             vram_bytes_to_free = vram_bytes_required - vram_available
             if vram_bytes_to_free <= 0:
                 break

             if cache_entry.is_locked:
+                # TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
+                # handling of model patches (e.g. LoRA).
                 continue

             cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
             if cache_entry_bytes_freed > 0:
@@ -478,7 +491,7 @@ class ModelCache:
         if self._execution_device.type != "cpu":
             vram_in_use_bytes = self._get_vram_in_use()
-            vram_available_bytes = self._get_vram_available()
+            vram_available_bytes = self._get_vram_available(None)
             vram_size_bytes = vram_in_use_bytes + vram_available_bytes
             vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
             vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
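
Threading working_mem_bytes explicitly through these calls matters because the "available" figure is net of reserved working memory and can legitimately go negative, which is exactly the condition the new last-resort path checks for. Below is a minimal, hypothetical sketch of that accounting; it uses torch.cuda.mem_get_info for illustration and is not how the real _get_vram_available is necessarily implemented.

from typing import Optional

import torch


def vram_available_bytes(device: torch.device, working_mem_bytes: Optional[int]) -> int:
    """Sketch: free VRAM on `device` minus the working memory held in reserve.

    A negative result means the reserve cannot currently be satisfied, which is the
    condition the lock path above treats as "insufficient VRAM".
    """
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    return free_bytes - (working_mem_bytes or 0)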


@@ -98,6 +98,37 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
     assert model.linear2.is_device_autocasting_enabled()


+@parameterize_mps_and_cuda
+def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+    # Model starts in CPU memory.
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Full load the model into VRAM.
+    cached_model.full_load_to_vram()
+    assert cached_model.cur_vram_bytes() == model_total_bytes
+
+    # Partially unload the model from VRAM, but request the required weights to be kept in VRAM.
+    bytes_to_free = int(model_total_bytes)
+    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free, keep_required_weights_in_vram=True)
+
+    # Check that the model is partially unloaded from VRAM.
+    assert freed_bytes < model_total_bytes
+    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
+    assert freed_bytes == sum(
+        calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
+    )
+
+    # The parameters should be offloaded to the CPU, because they are in Linear layers.
+    assert all(p.device.type == "cpu" for p in model.parameters())
+
+    # The buffer should still be on the device, because it is in a layer that does not support autocast.
+    assert all(p.device.type == device for p in model.buffers())
+
+    # Check that the model's modules still have device autocasting enabled.
+    assert model.linear1.is_device_autocasting_enabled()
+    assert model.linear2.is_device_autocasting_enabled()
+
+
 @parameterize_mps_and_cuda
 def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
     cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))