diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py
index 7db653c3fb..365bd9cb6c 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_base.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py
@@ -30,23 +30,13 @@ writes to the system log is stored in InvocationServices.performance_statistics.
 
 from abc import ABC, abstractmethod
 from contextlib import AbstractContextManager
-from typing import Dict
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
-from invokeai.backend.model_management.model_cache import CacheStats
-
-from .invocation_stats_common import NodeLog
 
 
 class InvocationStatsServiceBase(ABC):
     "Abstract base class for recording node memory/time performance statistics"
 
-    # {graph_id => NodeLog}
-    _stats: Dict[str, NodeLog]
-    _cache_stats: Dict[str, CacheStats]
-    ram_used: float
-    ram_changed: float
-
     @abstractmethod
     def __init__(self):
         """
@@ -76,46 +66,9 @@ class InvocationStatsServiceBase(ABC):
         """
         pass
 
-    @abstractmethod
-    def reset_all_stats(self):
-        """Zero all statistics"""
-        pass
-
-    @abstractmethod
-    def update_invocation_stats(
-        self,
-        graph_id: str,
-        invocation_type: str,
-        time_used: float,
-        vram_used: float,
-    ):
-        """
-        Add timing information on execution of a node. Usually
-        used internally.
-        :param graph_id: ID of the graph that is currently executing
-        :param invocation_type: String literal type of the node
-        :param time_used: Time used by node's exection (sec)
-        :param vram_used: Maximum VRAM used during exection (GB)
-        """
-        pass
-
     @abstractmethod
     def log_stats(self):
         """
         Write out the accumulated statistics to the log or somewhere else.
         """
         pass
-
-    @abstractmethod
-    def update_mem_stats(
-        self,
-        ram_used: float,
-        ram_changed: float,
-    ):
-        """
-        Update the collector with RAM memory usage info.
-
-        :param ram_used: How much RAM is currently in use.
-        :param ram_changed: How much RAM changed since last generation.
-        """
-        pass
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_common.py b/invokeai/app/services/invocation_stats/invocation_stats_common.py
index 19b954f6da..0059efc552 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_common.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_common.py
@@ -1,25 +1,88 @@
-from dataclasses import dataclass, field
-from typing import Dict
+from collections import defaultdict
+from dataclasses import dataclass
 
 # size of GIG in bytes
 GIG = 1073741824
 
 
 @dataclass
-class NodeStats:
-    """Class for tracking execution stats of an invocation node"""
+class NodeExecutionStats:
+    """Class for tracking execution stats of an invocation node."""
 
-    calls: int = 0
-    time_used: float = 0.0  # seconds
-    max_vram: float = 0.0  # GB
-    cache_hits: int = 0
-    cache_misses: int = 0
-    cache_high_watermark: int = 0
+    invocation_type: str
+
+    start_time: float  # Seconds since the epoch.
+    end_time: float  # Seconds since the epoch.
+
+    start_ram_gb: float  # GB
+    end_ram_gb: float  # GB
+
+    peak_vram_gb: float  # GB
+
+    def total_time(self) -> float:
+        """Get the node's execution duration in seconds (end_time - start_time)."""
+        return self.end_time - self.start_time
 
 
-@dataclass
-class NodeLog:
-    """Class for tracking node usage"""
+class GraphExecutionStats:
+    """Class for tracking execution stats of a graph."""
 
-    # {node_type => NodeStats}
-    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+    def __init__(self):
+        self._node_stats_list: list[NodeExecutionStats] = []
+
+    def add_node_execution_stats(self, node_stats: NodeExecutionStats):
+        self._node_stats_list.append(node_stats)
+
+    def get_total_run_time(self) -> float:
+        """Get the total time spent executing nodes in the graph."""
+        total = 0.0
+        for node_stats in self._node_stats_list:
+            total += node_stats.total_time()
+        return total
+
+    def get_first_node_stats(self) -> NodeExecutionStats | None:
+        """Get the stats of the first node in the graph (by start_time), or None if no nodes were recorded."""
+        first_node = None
+        for node_stats in self._node_stats_list:
+            if first_node is None or node_stats.start_time < first_node.start_time:
+                first_node = node_stats
+
+        return first_node
+
+    def get_last_node_stats(self) -> NodeExecutionStats | None:
+        """Get the stats of the last node in the graph (by end_time), or None if no nodes were recorded."""
+        last_node = None
+        for node_stats in self._node_stats_list:
+            if last_node is None or node_stats.end_time > last_node.end_time:
+                last_node = node_stats
+
+        return last_node
+
+    def get_pretty_log(self, graph_execution_state_id: str) -> str:
+        """Build a multi-line report: per-node-type stats followed by totals for the whole graph."""
+        log = f"Graph stats: {graph_execution_state_id}\n"
+        log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
+
+        # Log stats aggregated by node type.
+        node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
+        for node_stats in self._node_stats_list:
+            node_stats_by_type[node_stats.invocation_type].append(node_stats)
+
+        for node_type, node_type_stats_list in node_stats_by_type.items():
+            num_calls = len(node_type_stats_list)
+            time_used = sum(n.total_time() for n in node_type_stats_list)
+            peak_vram = max(n.peak_vram_gb for n in node_type_stats_list)
+            log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
+
+        # Log stats for the entire graph.
+        log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
+
+        first_node = self.get_first_node_stats()
+        last_node = self.get_last_node_stats()
+        if first_node is not None and last_node is not None:
+            total_wall_time = last_node.end_time - first_node.start_time
+            ram_change = last_node.end_ram_gb - first_node.start_ram_gb
+            log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
+            log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
+
+        return log
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py
index 34d2cd8354..50b1a1bee7 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_default.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py
@@ -1,5 +1,5 @@
 import time
-from typing import Dict
+from contextlib import contextmanager
 
 import psutil
 import torch
@@ -7,85 +7,54 @@ import torch
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
 from invokeai.backend.model_management.model_cache import CacheStats
 
 from .invocation_stats_base import InvocationStatsServiceBase
-from .invocation_stats_common import GIG, NodeLog, NodeStats
+from .invocation_stats_common import GIG, GraphExecutionStats, NodeExecutionStats
 
 
 class InvocationStatsService(InvocationStatsServiceBase):
     """Accumulate performance information about a running graph. Collects time spent in each node,
     as well as the maximum and current VRAM utilisation for CUDA systems"""
 
-    _invoker: Invoker
-
     def __init__(self):
-        # {graph_id => NodeLog}
-        self._stats: Dict[str, NodeLog] = {}
-        self._cache_stats: Dict[str, CacheStats] = {}
-        self.ram_used: float = 0.0
-        self.ram_changed: float = 0.0
+        # Maps graph_execution_state_id to GraphExecutionStats.
+        self._stats: dict[str, GraphExecutionStats] = {}
+        # Maps graph_execution_state_id to model manager CacheStats.
+        self._cache_stats: dict[str, CacheStats] = {}
 
     def start(self, invoker: Invoker) -> None:
         self._invoker = invoker
 
-    class StatsContext:
-        """Context manager for collecting statistics."""
-
-        invocation: BaseInvocation
-        collector: "InvocationStatsServiceBase"
-        graph_id: str
-        start_time: float
-        ram_used: int
-        model_manager: ModelManagerServiceBase
-
-        def __init__(
-            self,
-            invocation: BaseInvocation,
-            graph_id: str,
-            model_manager: ModelManagerServiceBase,
-            collector: "InvocationStatsServiceBase",
-        ):
-            """Initialize statistics for this run."""
-            self.invocation = invocation
-            self.collector = collector
-            self.graph_id = graph_id
-            self.start_time = 0.0
-            self.ram_used = 0
-            self.model_manager = model_manager
-
-        def __enter__(self):
-            self.start_time = time.time()
-            if torch.cuda.is_available():
-                torch.cuda.reset_peak_memory_stats()
-            self.ram_used = psutil.Process().memory_info().rss
-            if self.model_manager:
-                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
-
-        def __exit__(self, *args):
-            """Called on exit from the context."""
-            ram_used = psutil.Process().memory_info().rss
-            self.collector.update_mem_stats(
-                ram_used=ram_used / GIG,
-                ram_changed=(ram_used - self.ram_used) / GIG,
-            )
-            self.collector.update_invocation_stats(
-                graph_id=self.graph_id,
-                invocation_type=self.invocation.type,  # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
-                time_used=time.time() - self.start_time,
-                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
-            )
-
-    def collect_stats(
-        self,
-        invocation: BaseInvocation,
-        graph_execution_state_id: str,
-    ) -> StatsContext:
-        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
-            self._stats[graph_execution_state_id] = NodeLog()
+    @contextmanager
+    def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
+        if not self._stats.get(graph_execution_state_id):
+            # First time we're seeing this graph_execution_state_id.
+            self._stats[graph_execution_state_id] = GraphExecutionStats()
             self._cache_stats[graph_execution_state_id] = CacheStats()
-        return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
+
+        # Record state before the invocation.
+        start_time = time.time()
+        start_ram = psutil.Process().memory_info().rss
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+        if self._invoker.services.model_manager:
+            self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
+
+        try:
+            # Let the invocation run.
+            yield None
+        finally:
+            # Record state after the invocation.
+            node_stats = NodeExecutionStats(
+                invocation_type=invocation.type,
+                start_time=start_time,
+                end_time=time.time(),
+                start_ram_gb=start_ram / GIG,
+                end_ram_gb=psutil.Process().memory_info().rss / GIG,
+                peak_vram_gb=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
+            )
+            self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
     def reset_all_stats(self):
         """Zero all statistics"""
@@ -97,28 +66,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         except KeyError:
             logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
 
-    def update_mem_stats(
-        self,
-        ram_used: float,
-        ram_changed: float,
-    ):
-        self.ram_used = ram_used
-        self.ram_changed = ram_changed
-
-    def update_invocation_stats(
-        self,
-        graph_id: str,
-        invocation_type: str,
-        time_used: float,
-        vram_used: float,
-    ):
-        if not self._stats[graph_id].nodes.get(invocation_type):
-            self._stats[graph_id].nodes[invocation_type] = NodeStats()
-        stats = self._stats[graph_id].nodes[invocation_type]
-        stats.calls += 1
-        stats.time_used += time_used
-        stats.max_vram = max(stats.max_vram, vram_used)
-
     def log_stats(self):
         completed = set()
         errored = set()
@@ -132,29 +79,23 @@ class InvocationStatsService(InvocationStatsServiceBase):
             if not current_graph_state.is_complete():
                 continue
 
-            total_time = 0
-            logger.info(f"Graph stats: {graph_id}")
-            logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
-            for node_type, stats in self._stats[graph_id].nodes.items():
-                logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
-                total_time += stats.time_used
+            graph_stats = self._stats[graph_id]
+            log = graph_stats.get_pretty_log(graph_id)
 
             cache_stats = self._cache_stats[graph_id]
             hwm = cache_stats.high_watermark / GIG
             tot = cache_stats.cache_size / GIG
             loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
-
-            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
-            logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
-            logger.info(f"RAM used to load models: {loaded:4.2f}G")
+            log += f"RAM used to load models: {loaded:4.2f}G\n"
             if torch.cuda.is_available():
-                logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
-            logger.info("RAM cache statistics:")
-            logger.info(f" Model cache hits: {cache_stats.hits}")
-            logger.info(f" Model cache misses: {cache_stats.misses}")
-            logger.info(f" Models cached: {cache_stats.in_cache}")
-            logger.info(f" Models cleared from cache: {cache_stats.cleared}")
-            logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
+                log += f"VRAM in use: {(torch.cuda.memory_allocated() / GIG):4.3f}G\n"
+            log += "RAM cache statistics:\n"
+            log += f" Model cache hits: {cache_stats.hits}\n"
+            log += f" Model cache misses: {cache_stats.misses}\n"
+            log += f" Models cached: {cache_stats.in_cache}\n"
+            log += f" Models cleared from cache: {cache_stats.cleared}\n"
+            log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
+            logger.info(log)
 
             completed.add(graph_id)
 