Add log_memory_usage param to ModelCache.

This commit is contained in:
Ryan Dick 2023-11-02 12:00:07 -04:00 committed by Kent Keirsey
parent 267e709ba2
commit 3781e56e57

View File

@ -117,6 +117,7 @@ class ModelCache(object):
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
log_memory_usage: bool = False,
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@ -126,6 +127,10 @@ class ModelCache(object):
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
# allow lazy offloading only when vram cache enabled
@ -137,6 +142,7 @@ class ModelCache(object):
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
@ -144,6 +150,11 @@ class ModelCache(object):
self._cached_models = dict()
self._cache_stack = list()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def get_key(
self,
model_path: str,
@ -223,10 +234,10 @@ class ModelCache(object):
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_load_time = time.time()
self_reported_model_size_after_load = model_info.get_size(submodel)
@ -275,9 +286,9 @@ class ModelCache(object):
return
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
@ -286,7 +297,12 @@ class ModelCache(object):
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if snapshot_before.vram is not None and snapshot_after.vram is not None:
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.