report RAM and RAM cache statistics

This commit is contained in:
Lincoln Stein
2023-08-15 21:00:30 -04:00
parent a4b029d03c
commit ec10aca91e
4 changed files with 115 additions and 36 deletions

View File

@ -43,6 +43,11 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats
# size of GIG in bytes
GIG = 1073741824
class InvocationStatsServiceBase(ABC):
@ -84,14 +89,15 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def update_invocation_stats(self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
@ -119,6 +125,9 @@ class NodeStats:
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
@ -137,36 +146,50 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
class StatsContext:
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
"""Context manager for collecting statistics."""
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.ram_info = None
self.ram_used = 0
self.model_manager = model_manager
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_info = psutil.virtual_memory()
self.ram_used = psutil.Process().memory_info().rss
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_invocation_stats(
graph_id = self.graph_id,
invocation_type = self.invocation.type,
time_used = time.time() - self.start_time,
vram_used = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
ram_used = psutil.virtual_memory().used / 1e9,
ram_changed = (psutil.virtual_memory().used - self.ram_info.used) / 1e9,
graph_id=self.graph_id,
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
@ -175,7 +198,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
def reset_all_stats(self):
"""Zero all statistics"""
@ -188,14 +212,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
@ -218,7 +243,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed if when the execution of the graph
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
@ -235,11 +260,21 @@ class InvocationStatsService(InvocationStatsServiceBase):
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("Current RAM used: " + "%4.2fG" % stats.ram_used + f" (delta={stats.ram_changed:4.2f}G)")
logger.info("RAM used: " + "%4.2fG" % stats.ram_used + f" (delta={stats.ram_changed:4.2f}G)")
if torch.cuda.is_available():
logger.info("Current VRAM used: " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
logger.info("VRAM used (all processes): " + "%4.2fG" % (torch.cuda.memory_allocated() / GIG))
cache_stats = self._cache_stats[graph_id]
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
logger.info(f" Cache RAM usage: {hwm:4.2f}/{tot:4.2f}G")
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@ -22,6 +22,7 @@ from invokeai.backend.model_management import (
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats
import torch
from invokeai.app.models.exceptions import CanceledException
@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.

View File

@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,