copy model from a meta device template

- temporarily disable vram cache
Lincoln Stein 2024-06-24 10:55:15 -04:00
parent 6932f27b43
commit 2219e3643a
4 changed files with 40 additions and 122 deletions

View File

@@ -52,11 +52,10 @@ class CacheRecord(Generic[T]):
     Elements of the cache:

         key: Unique key for each model, same as used in the models database.
-        model: Model in memory.
+        model: Read-only copy of the model *without weights* residing in the "meta device"
         state_dict: A read-only copy of the model's state dict in RAM. It will be
             used as a template for creating a copy in the VRAM.
         size: Size of the model
-        loaded: True if the model's state dict is currently in VRAM

     Before a model is executed, the state_dict template is copied into VRAM,
     and then injected into the model. When the model is finished, the VRAM
@@ -72,25 +71,7 @@ class CacheRecord(Generic[T]):
     key: str
     size: int
     model: T
-    device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0


 @dataclass
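Taken together, the two hunks above reduce CacheRecord to four fields. A rough sketch of the resulting dataclass, reconstructed from the hunks rather than copied from the repository:

    from dataclasses import dataclass
    from typing import Dict, Generic, Optional, TypeVar

    import torch

    T = TypeVar("T")


    @dataclass
    class CacheRecord(Generic[T]):
        """Sketch of the slimmed-down record: no device, no loaded flag, no lock counter."""

        key: str                                       # same key as in the models database
        size: int                                      # estimated size of the model in bytes
        model: T                                       # weight-less copy residing on the meta device
        state_dict: Optional[Dict[str, torch.Tensor]]  # master copy of the weights in RAM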

View File

@@ -36,6 +36,7 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

+from ..optimizations import skip_torch_weight_init
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker
@@ -221,8 +222,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)

-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        if isinstance(model, torch.nn.Module):
+            state_dict = model.state_dict()  # keep a master copy of the state dict
+            model = model.to(device="meta")  # and keep a template in the meta device
+        else:
+            state_dict = None
+        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)

         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
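The state_dict() snapshot taken before the move holds detached CPU tensors, so the master copy of the weights survives when the module itself is sent to the meta device. A minimal, self-contained illustration of that behaviour, using a plain nn.Linear as a stand-in for a real model:

    import torch
    from torch import nn

    model = nn.Linear(1024, 1024)
    state_dict = model.state_dict()   # detached references to the CPU weight tensors
    model = model.to(device="meta")   # parameters become shape-only meta tensors

    assert model.weight.is_meta                      # the cached template carries no weight storage
    assert not state_dict["weight"].is_meta          # the RAM master copy is unaffected
    assert state_dict["weight"].device.type == "cpu"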
@ -284,48 +289,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
else: else:
return model_key return model_key
def offload_unlocked_models(self, size_required: int) -> None: def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move any unused models from VRAM.""" """Move a copy of the model into the indicated device and return it.
device = self.get_execution_device()
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated(device) + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if cache_entry.device is not device:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model :param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into :param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError May raise a torch.cuda.OutOfMemoryError
""" """
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. # Some models don't have a state dictionary, in which case the
# This would need to be revised to support multi-GPU. # stored model will still reside in CPU
if torch.device(source_device).type == torch.device(target_device).type: if cache_entry.state_dict is None:
return return cache_entry.model
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid # This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
@ -338,27 +315,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
# in RAM into the model. So this operation is very fast. # in RAM into the model. So this operation is very fast.
start_model_to_time = time.time() start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot() snapshot_before = self._capture_memory_snapshot()
try: try:
if cache_entry.state_dict is not None: assert isinstance(cache_entry.model, torch.nn.Module)
assert hasattr(cache_entry.model, "load_state_dict") template = cache_entry.model
if target_device == self.storage_device: cls = template.__class__
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) with skip_torch_weight_init():
if hasattr(cls, "from_config"):
working_model = template.__class__.from_config(template.config) # diffusers style
else: else:
new_dict: Dict[str, torch.Tensor] = {} working_model = template.__class__(config=template.config) # transformers style (sigh)
for k, v in cache_entry.state_dict.items(): working_model.to(device=target_device, dtype=self._precision)
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True) working_model.load_state_dict(cache_entry.state_dict)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)
raise e raise e
snapshot_after = self._capture_memory_snapshot() snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time() end_model_to_time = time.time()
self.logger.debug( self.logger.info(
f"Moved model '{cache_entry.key}' from {source_device} to" f"Moved model '{cache_entry.key}' to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
@@ -380,34 +355,21 @@ class ModelCache(ModelCacheBase[AnyModel]):
             abs_tol=10 * MB,
         ):
             self.logger.debug(
-                f"Moving model '{cache_entry.key}' from {source_device} to"
+                f"Moving model '{cache_entry.key}' from to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
                 f" {(cache_entry.size/GIG):.3f} GB.\n"
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

+        return working_model
+
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
         ram = "%4.2fG" % (self.cache_size() / GIG)
-        in_ram_models = 0
-        in_vram_models = 0
-        locked_in_vram_models = 0
-        for cache_record in self._cached_models.values():
-            if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
-                    in_ram_models += 1
-                else:
-                    in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
-        self.logger.debug(
-            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
-            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
-        )
+        in_ram_models = len(self._cached_models)
+        self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")

     def make_room(self, size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
@@ -433,29 +395,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             refs = sys.getrefcount(cache_entry.model)

-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                    # break
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"

View File

@@ -37,25 +37,22 @@ class ModelLocker(ModelLockerBase):
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        self._cache_entry.lock()
         try:
             device = self._cache.get_execution_device()
-            self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, device)
-            self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
+            model_on_device = self._cache.model_to_device(self._cache_entry, device)
+            self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
-            self._cache_entry.unlock()
             raise
         except Exception:
-            self._cache_entry.unlock()
             raise
-        return self.model
+        return model_on_device

+    # It is no longer necessary to move the model out of VRAM
+    # because it will be removed when it goes out of scope
+    # in the caller's context
     def unlock(self) -> None:
         """Call upon exit from context."""
-        self._cache_entry.unlock()
         self._cache.print_cuda_stats()
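With the lock counter gone, lock() simply hands back the freshly materialised copy and unlock() has nothing to undo; dropping the returned object is what frees the VRAM. A hedged sketch of a calling pattern built on that contract (the on_device() wrapper is an illustrative helper, not part of the codebase):

    from contextlib import contextmanager

    import torch


    @contextmanager
    def on_device(locker):
        """Illustrative wrapper: yield the per-use copy from lock(), discard it on exit."""
        working_model = locker.lock()  # builds/moves a fresh copy onto the execution device
        try:
            yield working_model
        finally:
            locker.unlock()            # now only logs stats; nothing is copied back to RAM
            del working_model          # the VRAM copy is reclaimed once no references remain
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # optionally hand freed blocks back to the driver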

View File

@@ -129,9 +129,7 @@ class ModelPatcher:
                     dtype = module.weight.dtype

                     if module_key not in original_weights:
-                        if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                            original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                        else:
+                        if model_state_dict is None:  # no CPU copy of the state dict was provided
                             original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                     layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
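For context, ModelPatcher backs up each patched weight on the CPU, applies the low-rank delta scaled by alpha/rank, and copies the backup back on exit. A minimal stand-alone version of that pattern with a single nn.Linear and hypothetical LoRA factors (not the actual ModelPatcher code):

    import torch
    from torch import nn

    module = nn.Linear(8, 8)
    original_weights = {}

    # back up the weight on the CPU before patching (the branch this commit keeps)
    original_weights["weight"] = module.weight.detach().to(device="cpu", copy=True)

    # apply a LoRA-style delta: W += (alpha / rank) * (up @ down)
    rank, alpha = 4, 4.0
    down, up = torch.randn(rank, 8), torch.randn(8, rank)
    layer_scale = alpha / rank if (alpha and rank) else 1.0
    with torch.no_grad():
        module.weight += layer_scale * (up @ down)

    # on exit from the patch context, the backup is restored
    with torch.no_grad():
        module.weight.copy_(original_weights["weight"])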
@@ -158,6 +156,9 @@ class ModelPatcher:
             yield  # wait for context manager exit

         finally:
+            # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
+            # Therefore it should not be necessary to copy the original model weights back.
+            # This needs to be fixed before resurrecting the VRAM cache.
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():