Add debug logging of changes in RAM and VRAM for all model cache operations.

This commit is contained in:
Ryan Dick 2023-09-28 16:27:16 -04:00
parent 44d68f5ed5
commit 594fd3ba6d

View File

@ -20,11 +20,13 @@ import gc
import hashlib
import os
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
import psutil
import torch
import invokeai.backend.util.logging as logger
@ -211,17 +213,27 @@ class ModelCache(object):
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
self_reported_model_size_before_load = model_info.get_size(submodel)
# Remove old models from the cache to make room for the new model.
self._make_cache_room(self_reported_model_size_before_load)
# clean memory to make MemoryUsage() more accurate
gc.collect()
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
snapshot_after = MemorySnapshot.capture()
end_load_time = time.time()
cache_entry = _CacheRecord(self, model, mem_used)
self_reported_model_size_after_load = model_info.get_size(submodel)
self.logger.debug(
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s. Self-reported size"
f" before/after load: {(self_reported_model_size_before_load/GIG):.2f}GB /"
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
)
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
self._cached_models[key] = cache_entry
else:
if self.stats:
@ -271,10 +283,17 @@ class ModelCache(object):
self.cache._offload_unlocked_models(self.size_needed)
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
self.model.to(self.cache.execution_device) # move into GPU
snapshot_after = MemorySnapshot.capture()
end_model_to_time = time.time()
self.cache.logger.debug(
f"Moved model '{self.key}' from {self.cache.storage_device} to"
f" {self.cache.execution_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f" Estimated model size: {(self.cache_entry.size/GIG):.2f} GB."
f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
)
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
@ -427,11 +446,19 @@ class ModelCache(object):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
vram_in_use += mem.vram_used # note vram_used is negative
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
cache_entry.model.to(self.storage_device)
snapshot_after = MemorySnapshot.capture()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{model_key}' from {self.execution_device} to {self.storage_device} in"
f" {(end_model_to_time-start_model_to_time):.2f}s. Estimated model size:"
f" {(cache_entry.size/GIG):.2f} GB."
f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
)
vram_in_use = snapshot_after.vram or 0
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
@ -460,14 +487,60 @@ class ModelCache(object):
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
class MemorySnapshot:
"""A snapshot of RAM and VRAM usage. All values are in bytes."""
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __init__(self, process_ram: int, vram: Optional[int]):
"""Initialize a MemorySnapshot.
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
Args:
process_ram (int): CPU RAM used by the current process.
vram (Optional[int]): VRAM used by torch.
"""
self.process_ram = process_ram
self.vram = vram
@classmethod
def capture(cls, run_garbage_collector: bool = True):
"""Capture and return a MemorySnapshot.
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
Args:
run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
usage. Defaults to True.
Returns:
MemorySnapshot
"""
if run_garbage_collector:
gc.collect()
# According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
# supported on all platforms.
process_ram = psutil.Process().memory_info().rss
if choose_torch_device() == torch.device("cuda"):
vram = torch.cuda.memory_allocated()
else:
# TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
# time to test it properly.
vram = None
return cls(process_ram, vram)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
ram_diff = snapshot_2.process_ram - snapshot_1.process_ram
msg = f"RAM ({(ram_diff/GIG):+.2f}): {(snapshot_1.process_ram/GIG):.2f}GB -> {(snapshot_2.process_ram/GIG):.2f}GB"
vram_diff = None
if snapshot_1.vram is not None and snapshot_2.vram is not None:
vram_diff = snapshot_2.vram - snapshot_1.vram
msg += f", VRAM ({(vram_diff/GIG):+.2f}): {(snapshot_1.vram/GIG):.2f}GB -> {(snapshot_2.vram/GIG):.2f}GB"
return msg