Add log_memory_usage param to ModelCache.

Author: Ryan Dick, 2023-11-02 12:00:07 -04:00; committed by: Kent Keirsey
parent 267e709ba2
commit 3781e56e57


@@ -117,6 +117,7 @@ class ModelCache(object):
         lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         logger: types.ModuleType = logger,
+        log_memory_usage: bool = False,
     ):
         """
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -126,6 +127,10 @@ class ModelCache(object):
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
+        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
+            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
+            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model
+            cache's behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = dict()
         # allow lazy offloading only when vram cache enabled
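
For context, opting in from calling code would look roughly like the sketch below. This is a hedged illustration, not part of the commit: log_memory_usage is the only argument added here, and max_cache_size appears only because the docstring above documents it.

    # Usage sketch (hypothetical call site, not in this diff):
    cache = ModelCache(
        max_cache_size=6.0,  # GB, the documented default
        log_memory_usage=True,  # log a snapshot diff around every cache operation
    )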
@@ -137,6 +142,7 @@ class ModelCache(object):
         self.storage_device: torch.device = storage_device
         self.sha_chunksize = sha_chunksize
         self.logger = logger
+        self._log_memory_usage = log_memory_usage
 
         # used for stats collection
         self.stats = None
@@ -144,6 +150,11 @@ class ModelCache(object):
         self._cached_models = dict()
         self._cache_stack = list()
 
+    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
+        if self._log_memory_usage:
+            return MemorySnapshot.capture()
+        return None
+
     def get_key(
         self,
         model_path: str,
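
Every capture now funnels through this helper, so the time cost of MemorySnapshot.capture() is paid only when the flag is on, and callers receive None otherwise. A self-contained sketch of the same gate pattern (Probe and Reading are placeholder names, not part of this codebase):

    from typing import Optional

    class Reading:
        """Placeholder for MemorySnapshot: whatever capture() returns."""

    class Probe:
        def __init__(self, enabled: bool) -> None:
            self._enabled = enabled

        def capture(self) -> Optional[Reading]:
            # Pay the capture cost only when the feature is switched on;
            # callers must handle a None result (see the final hunk below).
            if self._enabled:
                return Reading()
            return None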
@@ -223,10 +234,10 @@ class ModelCache(object):
 
         # Load the model from disk and capture a memory snapshot before/after.
         start_load_time = time.time()
-        snapshot_before = MemorySnapshot.capture()
+        snapshot_before = self._capture_memory_snapshot()
         with skip_torch_weight_init():
             model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-        snapshot_after = MemorySnapshot.capture()
+        snapshot_after = self._capture_memory_snapshot()
         end_load_time = time.time()
 
         self_reported_model_size_after_load = model_info.get_size(submodel)
@@ -275,9 +286,9 @@ class ModelCache(object):
             return
 
         start_model_to_time = time.time()
-        snapshot_before = MemorySnapshot.capture()
+        snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
-        snapshot_after = MemorySnapshot.capture()
+        snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
             f"Moved model '{key}' from {source_device} to"
@@ -286,7 +297,12 @@ class ModelCache(object):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )
 
-        if snapshot_before.vram is not None and snapshot_after.vram is not None:
+        if (
+            snapshot_before is not None
+            and snapshot_after is not None
+            and snapshot_before.vram is not None
+            and snapshot_after.vram is not None
+        ):
             vram_change = abs(snapshot_before.vram - snapshot_after.vram)
 
             # If the estimated model size does not match the change in VRAM, log a warning.
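
Because both snapshots may now be None, the guard above checks for missing snapshots before dereferencing .vram; get_pretty_snapshot_diff is presumably None-tolerant, since it is still called unconditionally two hunks up. The widened condition, restated as a standalone helper (a sketch only; no such function exists in the codebase):

    from typing import Optional

    def vram_delta(before, after) -> Optional[int]:
        # No delta is measurable if logging was disabled (snapshot is None)
        # or if a snapshot recorded no VRAM (e.g. no GPU present).
        if before is None or after is None:
            return None
        if before.vram is None or after.vram is None:
            return None
        return abs(before.vram - after.vram)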