# Copyright 2023 Lincoln D. Stein """Utility to collect execution time and GPU usage stats on invocations in flight Usage: statistics = InvocationStatsService(graph_execution_manager) with statistics.collect_stats(invocation, graph_execution_state.id): ... execute graphs... statistics.log_stats() Typical output: [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3 [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used [2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s [2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G The abstract base class for this class is InvocationStatsServiceBase. An implementing class which writes to the system log is stored in InvocationServices.performance_statistics. """ import time from abc import ABC, abstractmethod from contextlib import AbstractContextManager from dataclasses import dataclass, field from typing import Dict import psutil import torch import invokeai.backend.util.logging as logger from invokeai.backend.model_management.model_cache import CacheStats from ..invocations.baseinvocation import BaseInvocation from .graph import GraphExecutionState from .item_storage import ItemStorageABC from .model_manager_service import ModelManagerService # size of GIG in bytes GIG = 1073741824 @dataclass class NodeStats: """Class for tracking execution stats of an invocation node""" calls: int = 0 time_used: float = 0.0 # seconds max_vram: float = 0.0 # GB cache_hits: int = 0 cache_misses: int = 0 cache_high_watermark: int = 0 @dataclass class NodeLog: """Class for tracking node usage""" # {node_type => NodeStats} nodes: Dict[str, NodeStats] = field(default_factory=dict) class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" graph_execution_manager: ItemStorageABC["GraphExecutionState"] # {graph_id => NodeLog} _stats: Dict[str, NodeLog] _cache_stats: Dict[str, CacheStats] ram_used: float ram_changed: float @abstractmethod def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): """ Initialize the InvocationStatsService and reset counters to zero :param graph_execution_manager: Graph execution manager for this session """ pass @abstractmethod def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, ) -> AbstractContextManager: """ Return a context object that will capture the statistics on the execution of invocaation. Use with: to place around the part of the code that executes the invocation. :param invocation: BaseInvocation object from the current graph. :param graph_execution_state: GraphExecutionState object from the current session. """ pass @abstractmethod def reset_stats(self, graph_execution_state_id: str): """ Reset all statistics for the indicated graph :param graph_execution_state_id """ pass @abstractmethod def reset_all_stats(self): """Zero all statistics""" pass @abstractmethod def update_invocation_stats( self, graph_id: str, invocation_type: str, time_used: float, vram_used: float, ): """ Add timing information on execution of a node. Usually used internally. :param graph_id: ID of the graph that is currently executing :param invocation_type: String literal type of the node :param time_used: Time used by node's exection (sec) :param vram_used: Maximum VRAM used during exection (GB) """ pass @abstractmethod def log_stats(self): """ Write out the accumulated statistics to the log or somewhere else. """ pass @abstractmethod def update_mem_stats( self, ram_used: float, ram_changed: float, ): """ Update the collector with RAM memory usage info. :param ram_used: How much RAM is currently in use. :param ram_changed: How much RAM changed since last generation. """ pass class InvocationStatsService(InvocationStatsServiceBase): """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): self.graph_execution_manager = graph_execution_manager # {graph_id => NodeLog} self._stats: Dict[str, NodeLog] = {} self._cache_stats: Dict[str, CacheStats] = {} self.ram_used: float = 0.0 self.ram_changed: float = 0.0 class StatsContext: """Context manager for collecting statistics.""" invocation: BaseInvocation collector: "InvocationStatsServiceBase" graph_id: str start_time: float ram_used: int model_manager: ModelManagerService def __init__( self, invocation: BaseInvocation, graph_id: str, model_manager: ModelManagerService, collector: "InvocationStatsServiceBase", ): """Initialize statistics for this run.""" self.invocation = invocation self.collector = collector self.graph_id = graph_id self.start_time = 0.0 self.ram_used = 0 self.model_manager = model_manager def __enter__(self): self.start_time = time.time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() self.ram_used = psutil.Process().memory_info().rss if self.model_manager: self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id]) def __exit__(self, *args): """Called on exit from the context.""" ram_used = psutil.Process().memory_info().rss self.collector.update_mem_stats( ram_used=ram_used / GIG, ram_changed=(ram_used - self.ram_used) / GIG, ) self.collector.update_invocation_stats( graph_id=self.graph_id, invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations time_used=time.time() - self.start_time, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, ) def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, model_manager: ModelManagerService, ) -> StatsContext: if not self._stats.get(graph_execution_state_id): # first time we're seeing this self._stats[graph_execution_state_id] = NodeLog() self._cache_stats[graph_execution_state_id] = CacheStats() return self.StatsContext(invocation, graph_execution_state_id, model_manager, self) def reset_all_stats(self): """Zero all statistics""" self._stats = {} def reset_stats(self, graph_execution_id: str): try: self._stats.pop(graph_execution_id) except KeyError: logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}") def update_mem_stats( self, ram_used: float, ram_changed: float, ): self.ram_used = ram_used self.ram_changed = ram_changed def update_invocation_stats( self, graph_id: str, invocation_type: str, time_used: float, vram_used: float, ): if not self._stats[graph_id].nodes.get(invocation_type): self._stats[graph_id].nodes[invocation_type] = NodeStats() stats = self._stats[graph_id].nodes[invocation_type] stats.calls += 1 stats.time_used += time_used stats.max_vram = max(stats.max_vram, vram_used) def log_stats(self): completed = set() errored = set() for graph_id, node_log in self._stats.items(): try: current_graph_state = self.graph_execution_manager.get(graph_id) except Exception: errored.add(graph_id) continue if not current_graph_state.is_complete(): continue total_time = 0 logger.info(f"Graph stats: {graph_id}") logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}") for node_type, stats in self._stats[graph_id].nodes.items(): logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G") total_time += stats.time_used cache_stats = self._cache_stats[graph_id] hwm = cache_stats.high_watermark / GIG tot = cache_stats.cache_size / GIG loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)") logger.info(f"RAM used to load models: {loaded:4.2f}G") if torch.cuda.is_available(): logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG)) logger.info("RAM cache statistics:") logger.info(f" Model cache hits: {cache_stats.hits}") logger.info(f" Model cache misses: {cache_stats.misses}") logger.info(f" Models cached: {cache_stats.in_cache}") logger.info(f" Models cleared from cache: {cache_stats.cleared}") logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G") completed.add(graph_id) for graph_id in completed: del self._stats[graph_id] del self._cache_stats[graph_id] for graph_id in errored: del self._stats[graph_id] del self._cache_stats[graph_id]