Mirror of https://github.com/invoke-ai/InvokeAI
Offload the current model when locking if it is already partially loaded and we have insufficient VRAM.
@@ -166,13 +166,17 @@ class CachedModelWithPartialLoad:
         return vram_bytes_loaded
 
     @torch.no_grad()
-    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
+    def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weights_in_vram: bool = False) -> int:
         """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
 
+        :param keep_required_weights_in_vram: If True, any weights that must be kept in VRAM to run the model will be
+            kept in VRAM.
+
         Returns:
             The number of bytes unloaded from VRAM.
         """
         vram_bytes_freed = 0
+        required_weights_in_vram = 0
 
         offload_device = "cpu"
         cur_state_dict = self._model.state_dict()
@@ -183,6 +187,10 @@ class CachedModelWithPartialLoad:
             if param.device.type == offload_device:
                 continue
 
+            if keep_required_weights_in_vram and key in self._keys_in_modules_that_do_not_support_autocast:
+                required_weights_in_vram += self._state_dict_bytes[key]
+                continue
+
             cur_state_dict[key] = self._cpu_state_dict[key]
             vram_bytes_freed += self._state_dict_bytes[key]
 
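For illustration only, here is a minimal standalone sketch of the bookkeeping the new keep_required_weights_in_vram flag adds: weights whose keys belong to modules that cannot be device-autocast are tallied and left on the GPU instead of being moved to the CPU. The toy module, the no_autocast_keys set, and the helper name below are stand-ins, not InvokeAI code.

import torch
from torch import nn


def sketch_partial_unload(model: nn.Module, no_autocast_keys: set[str], keep_required: bool) -> int:
    """Move parameters to the CPU, optionally keeping "required" keys on the GPU; returns bytes freed."""
    vram_bytes_freed = 0
    required_weights_in_vram = 0  # tracked, like in the real method, for reporting purposes
    for key, param in model.named_parameters():
        if param.device.type == "cpu":
            continue  # already offloaded
        if keep_required and key in no_autocast_keys:
            required_weights_in_vram += param.numel() * param.element_size()
            continue  # this weight must stay in VRAM for the model to remain runnable
        param.data = param.data.to("cpu")
        vram_bytes_freed += param.numel() * param.element_size()
    return vram_bytes_freed


if torch.cuda.is_available():
    toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to("cuda")
    # Pretend the second Linear lives in a module that does not support autocast.
    freed = sketch_partial_unload(toy, no_autocast_keys={"1.weight", "1.bias"}, keep_required=True)
    print(f"freed {freed} bytes; layer 1 stays on {toy[1].weight.device}")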
@@ -269,7 +269,7 @@ class ModelCache:
         # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
         # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
         #    possible.
-        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
+        vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
         self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
 
         # Check the updated vram_available after offloading.
@@ -278,6 +278,15 @@ class ModelCache:
             f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
         )
 
+        if vram_available < 0:
+            # There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
+            # as it may still be loaded from a previous use.
+            vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
+            vram_available = self._get_vram_available(working_mem_bytes)
+            self._logger.debug(
+                f"Unloaded {vram_bytes_freed_from_own_model/MB:.2f}MB from the model being locked ({cache_entry.key})."
+            )
+
         # Move as much of the model as possible into VRAM.
         # For testing, only allow 10% of the model to be loaded into VRAM.
         # vram_available = int(model_vram_needed * 0.1)
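A small worked example of the arithmetic behind the new if vram_available < 0: branch. The numbers and the availability formula are assumptions made for the illustration (the real figure comes from _get_vram_available, whose body is not part of this diff); the point is that the shortfall handed to _move_model_to_ram is exactly -vram_available.

GB = 2**30

# Assumed, illustrative numbers: an 8 GB GPU where the model being locked still holds
# 7 GB from a previous use, and 2 GB of working memory must remain free at runtime.
total_vram = 8 * GB
vram_in_use = 7 * GB            # owned entirely by the model being locked
working_mem_bytes = 2 * GB

# No unlocked models are resident, so _offload_unlocked_models frees nothing here.
vram_bytes_freed = 0

# Assumed formula: budget = total - in use - working memory (the real helper may use a
# configured cache limit rather than raw device totals).
vram_available = total_vram - vram_in_use - working_mem_bytes
assert vram_available == -1 * GB

if vram_available < 0:
    # The locked model itself is asked to shed exactly the shortfall: 1 GB.
    vram_bytes_to_free_from_own_model = -vram_available
    assert vram_bytes_to_free_from_own_model == 1 * GB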
@@ -318,7 +327,9 @@ class ModelCache:
     def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
         try:
             if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
-                return cache_entry.cached_model.partial_unload_from_vram(vram_bytes_to_free)
+                return cache_entry.cached_model.partial_unload_from_vram(
+                    vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
+                )
             elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad):  # type: ignore
                 return cache_entry.cached_model.full_unload_from_vram()
             else:
@@ -328,7 +339,7 @@ class ModelCache:
             self._delete_cache_entry(cache_entry)
             raise
 
-    def _get_vram_available(self, working_mem_bytes: Optional[int] = None) -> int:
+    def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
         """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
         memory).
         """
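This signature change drops the default so every caller must state its working-memory requirement explicitly (passing None where there is none, as in the stats hunk further down). Only the helper's docstring appears in this diff, so the snippet below is a guess at the semantics it describes, written against plain torch.cuda.mem_get_info; it is not the InvokeAI implementation.

from typing import Optional

import torch


def vram_available_sketch(working_mem_bytes: Optional[int]) -> int:
    """Hypothetical helper: free device memory minus the working memory reserved for inference.

    This only illustrates "takes into account the working memory" from the docstring above;
    a negative result means the reservation is already being violated.
    """
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return free_bytes - (working_mem_bytes or 0)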
@@ -421,7 +432,7 @@ class ModelCache:
             + f"vram_available={(vram_available/MB):.0f} MB, "
         )
 
-    def _offload_unlocked_models(self, vram_bytes_required: int) -> int:
+    def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
         """Offload models from the execution_device until vram_bytes_required bytes are available, or all models are
         offloaded. Of course, locked models are not offloaded.
 
@@ -436,11 +447,13 @@ class ModelCache:
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
-            vram_available = self._get_vram_available()
+            vram_available = self._get_vram_available(working_mem_bytes)
             vram_bytes_to_free = vram_bytes_required - vram_available
             if vram_bytes_to_free <= 0:
                 break
             if cache_entry.is_locked:
+                # TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
+                # handling of model patches (e.g. LoRA).
                 continue
             cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
             if cache_entry_bytes_freed > 0:
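To make the offload policy above concrete, here is a self-contained simulation of the same selection rule: visit cache entries smallest-first, skip locked entries, and re-check the shortfall on every iteration rather than trusting accumulated freed-byte counts. The Entry dataclass and the model names and sizes are invented for the example.

from dataclasses import dataclass


@dataclass
class Entry:
    key: str
    total_bytes: int
    is_locked: bool


def simulate_offload(entries: list[Entry], vram_bytes_required: int, vram_available: int) -> list[str]:
    """Return the keys that would be offloaded, smallest model first, skipping locked entries."""
    offloaded: list[str] = []
    for entry in sorted(entries, key=lambda e: e.total_bytes):
        vram_bytes_to_free = vram_bytes_required - vram_available
        if vram_bytes_to_free <= 0:
            break
        if entry.is_locked:
            continue  # locked models are never offloaded (see the TODO in the diff above)
        offloaded.append(entry.key)
        vram_available += entry.total_bytes  # the real cache re-queries availability instead of trusting this
    return offloaded


MB = 2**20
print(simulate_offload(
    [Entry("vae", 300 * MB, False), Entry("unet", 3500 * MB, True), Entry("clip", 700 * MB, False)],
    vram_bytes_required=800 * MB,
    vram_available=0,
))  # -> ['vae', 'clip']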
@@ -478,7 +491,7 @@ class ModelCache:
 
         if self._execution_device.type != "cpu":
             vram_in_use_bytes = self._get_vram_in_use()
-            vram_available_bytes = self._get_vram_available()
+            vram_available_bytes = self._get_vram_available(None)
             vram_size_bytes = vram_in_use_bytes + vram_available_bytes
             vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
             vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
@@ -98,6 +98,37 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
     assert model.linear2.is_device_autocasting_enabled()
 
 
+@parameterize_mps_and_cuda
+def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
+    # Model starts in CPU memory.
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Full load the model into VRAM.
+    cached_model.full_load_to_vram()
+    assert cached_model.cur_vram_bytes() == model_total_bytes
+
+    # Partially unload the model from VRAM, but request the required weights to be kept in VRAM.
+    bytes_to_free = int(model_total_bytes)
+    freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free, keep_required_weights_in_vram=True)
+
+    # Check that the model is partially unloaded from VRAM.
+    assert freed_bytes < model_total_bytes
+    assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
+    assert freed_bytes == sum(
+        calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
+    )
+    # The parameters should be offloaded to the CPU, because they are in Linear layers.
+    assert all(p.device.type == "cpu" for p in model.parameters())
+    # The buffer should still be on the device, because it is in a layer that does not support autocast.
+    assert all(p.device.type == device for p in model.buffers())
+
+    # Check that the model's modules still have device autocasting enabled.
+    assert model.linear1.is_device_autocasting_enabled()
+    assert model.linear2.is_device_autocasting_enabled()
+
+
 @parameterize_mps_and_cuda
 def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
     cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
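The new test relies on the DummyModule fixture, which is defined elsewhere in the test suite and is not part of this diff. Purely as a reading aid, a plausible shape for such a fixture is sketched below: parameters live in Linear layers (and are offloadable), while a registered buffer belongs to a layer the cache treats as not supporting autocast, which is why the test expects buffers to stay on the device when keep_required_weights_in_vram=True. The real fixture may differ.

import torch
from torch import nn


class DummyModuleSketch(nn.Module):
    """Hypothetical stand-in for the DummyModule test fixture (not the repository's definition)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(8, 16)
        self.linear2 = nn.Linear(16, 8)
        # A buffer that the cache would have to keep in VRAM because its owning layer
        # cannot be device-autocast on the fly.
        self.register_buffer("buffer1", torch.zeros(8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.linear1(x)) + self.buffer1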