diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py
index 1cfb0c1822..4951b51121 100644
--- a/invokeai/app/services/invocation_processor/invocation_processor_default.py
+++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py
@@ -1,11 +1,15 @@
 import time
 import traceback
+from contextlib import suppress
 from threading import BoundedSemaphore, Event, Thread
 from typing import Optional
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
+from invokeai.app.services.invocation_stats.invocation_stats_common import (
+    GESStatsNotFoundError,
+)
 from invokeai.app.util.profiler import Profiler
 
 from ..invoker import Invoker
@@ -152,7 +156,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     pass
 
                 except CanceledException:
-                    self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
+                    with suppress(GESStatsNotFoundError):
+                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
                     pass
 
                 except Exception as e:
@@ -177,7 +182,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=error,
                     )
-                    self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
+                    with suppress(GESStatsNotFoundError):
+                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
                     pass
 
                 # Check queue to see if this is canceled, and skip if so
@@ -209,15 +215,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            error=traceback.format_exc(),
                        )
                elif is_complete:
-                    self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
-                    self.__invoker.services.events.emit_graph_execution_complete(
-                        queue_batch_id=queue_item.session_queue_batch_id,
-                        queue_item_id=queue_item.session_queue_item_id,
-                        queue_id=queue_item.session_queue_id,
-                        graph_execution_state_id=graph_execution_state.id,
-                    )
-                    if profiler:
-                        profiler.stop()
+                    with suppress(GESStatsNotFoundError):
+                        self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
+                        self.__invoker.services.events.emit_graph_execution_complete(
+                            queue_batch_id=queue_item.session_queue_batch_id,
+                            queue_item_id=queue_item.session_queue_item_id,
+                            queue_id=queue_item.session_queue_id,
+                            graph_execution_state_id=graph_execution_state.id,
+                        )
+                        if profiler:
+                            profile_path = profiler.stop()
+                            stats_path = profile_path.with_suffix(".json")
+                            self.__invoker.services.performance_statistics.dump_stats(
+                                graph_execution_state_id=graph_execution_state.id, output_path=stats_path
+                            )
+                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
 
        except KeyboardInterrupt:
            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py
index 6e5b6a9f69..22624a6579 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_base.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py
@@ -30,8 +30,10 @@ writes to the system log is stored in InvocationServices.performance_statistics.
 
 from abc import ABC, abstractmethod
 from contextlib import AbstractContextManager
+from pathlib import Path
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
+from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
 
 
 class InvocationStatsServiceBase(ABC):
@@ -61,8 +63,9 @@ class InvocationStatsServiceBase(ABC):
     @abstractmethod
     def reset_stats(self, graph_execution_state_id: str):
         """
-        Reset all statistics for the indicated graph
-        :param graph_execution_state_id
+        Reset all statistics for the indicated graph.
+        :param graph_execution_state_id: The id of the session whose stats to reset.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
         """
         pass
 
@@ -70,5 +73,26 @@
     def log_stats(self, graph_execution_state_id: str):
         """
         Write out the accumulated statistics to the log or somewhere else.
+        :param graph_execution_state_id: The id of the session whose stats to log.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
+        """
+        pass
+
+    @abstractmethod
+    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
+        """
+        Gets the accumulated statistics for the indicated graph.
+        :param graph_execution_state_id: The id of the session whose stats to get.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
+        """
+        pass
+
+    @abstractmethod
+    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
+        """
+        Write out the accumulated statistics to the indicated path as JSON.
+        :param graph_execution_state_id: The id of the session whose stats to dump.
+        :param output_path: The file to write the stats to.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
         """
         pass
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_common.py b/invokeai/app/services/invocation_stats/invocation_stats_common.py
index 543edc076a..f4c906a58f 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_common.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_common.py
@@ -1,5 +1,91 @@
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
+
+
+class GESStatsNotFoundError(Exception):
+    """Raised when execution stats are not found for a given Graph Execution State."""
+
+
+@dataclass
+class NodeExecutionStatsSummary:
+    """The stats for a specific type of node."""
+
+    node_type: str
+    num_calls: int
+    time_used_seconds: float
+    peak_vram_gb: float
+
+
+@dataclass
+class ModelCacheStatsSummary:
+    """The stats for the model cache."""
+
+    high_water_mark_gb: float
+    cache_size_gb: float
+    total_usage_gb: float
+    cache_hits: int
+    cache_misses: int
+    models_cached: int
+    models_cleared: int
+
+
+@dataclass
+class GraphExecutionStatsSummary:
+    """The stats for the graph execution state."""
+
+    graph_execution_state_id: str
+    execution_time_seconds: float
+    # `wall_time_seconds`, `ram_usage_gb` and `ram_change_gb` are derived from the node execution stats.
+    # In some situations, there are no node stats, so these values are optional.
+    wall_time_seconds: Optional[float]
+    ram_usage_gb: Optional[float]
+    ram_change_gb: Optional[float]
+
+
+@dataclass
+class InvocationStatsSummary:
+    """
+    The accumulated stats for a graph execution.
+    Its `__str__` method returns a human-readable stats summary.
+ """ + + vram_usage_gb: Optional[float] + graph_stats: GraphExecutionStatsSummary + model_cache_stats: ModelCacheStatsSummary + node_stats: list[NodeExecutionStatsSummary] + + def __str__(self) -> str: + _str = "" + _str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n" + _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n" + + for summary in self.node_stats: + _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n" + + _str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n" + + if self.graph_stats.wall_time_seconds is not None: + _str += f"TOTAL GRAPH WALL TIME: {self.graph_stats.wall_time_seconds:7.3f}s\n" + + if self.graph_stats.ram_usage_gb is not None and self.graph_stats.ram_change_gb is not None: + _str += f"RAM used by InvokeAI process: {self.graph_stats.ram_usage_gb:4.2f}G ({self.graph_stats.ram_change_gb:+5.3f}G)\n" + + _str += f"RAM used to load models: {self.model_cache_stats.total_usage_gb:4.2f}G\n" + if self.vram_usage_gb: + _str += f"VRAM in use: {self.vram_usage_gb:4.3f}G\n" + _str += "RAM cache statistics:\n" + _str += f" Model cache hits: {self.model_cache_stats.cache_hits}\n" + _str += f" Model cache misses: {self.model_cache_stats.cache_misses}\n" + _str += f" Models cached: {self.model_cache_stats.models_cached}\n" + _str += f" Models cleared from cache: {self.model_cache_stats.models_cleared}\n" + _str += f" Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}/{self.model_cache_stats.cache_size_gb:4.2f}G\n" + + return _str + + def as_dict(self) -> dict[str, Any]: + """Returns the stats as a dictionary.""" + return asdict(self) @dataclass @@ -55,12 +141,33 @@ class GraphExecutionStats: return last_node - def get_pretty_log(self, graph_execution_state_id: str) -> str: - log = f"Graph stats: {graph_execution_state_id}\n" - log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n" + def get_graph_stats_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary: + """Get a summary of the graph stats.""" + first_node = self.get_first_node_stats() + last_node = self.get_last_node_stats() - # Log stats aggregated by node type. 
+        wall_time_seconds: Optional[float] = None
+        ram_usage_gb: Optional[float] = None
+        ram_change_gb: Optional[float] = None
+
+        if last_node and first_node:
+            wall_time_seconds = last_node.end_time - first_node.start_time
+            ram_usage_gb = last_node.end_ram_gb
+            ram_change_gb = last_node.end_ram_gb - first_node.start_ram_gb
+
+        return GraphExecutionStatsSummary(
+            graph_execution_state_id=graph_execution_state_id,
+            execution_time_seconds=self.get_total_run_time(),
+            wall_time_seconds=wall_time_seconds,
+            ram_usage_gb=ram_usage_gb,
+            ram_change_gb=ram_change_gb,
+        )
+
+    def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
+        """Get a summary of the node stats."""
+        summaries: list[NodeExecutionStatsSummary] = []
         node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
+
         for node_stats in self._node_stats_list:
             node_stats_by_type[node_stats.invocation_type].append(node_stats)
 
@@ -68,17 +175,9 @@ class GraphExecutionStats:
             num_calls = len(node_type_stats_list)
             time_used = sum([n.total_time() for n in node_type_stats_list])
             peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
-            log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
+            summary = NodeExecutionStatsSummary(
+                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
+            )
+            summaries.append(summary)
 
-        # Log stats for the entire graph.
-        log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
-
-        first_node = self.get_first_node_stats()
-        last_node = self.get_last_node_stats()
-        if first_node is not None and last_node is not None:
-            total_wall_time = last_node.end_time - first_node.start_time
-            ram_change = last_node.end_ram_gb - first_node.start_ram_gb
-            log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
-            log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
-
-        return log
+        return summaries
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py
index 93f396c2b9..a2652dabc3 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_default.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py
@@ -1,5 +1,7 @@
+import json
 import time
 from contextlib import contextmanager
+from pathlib import Path
 
 import psutil
 import torch
@@ -10,7 +12,15 @@ from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_management.model_cache import CacheStats
 
 from .invocation_stats_base import InvocationStatsServiceBase
-from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
+from .invocation_stats_common import (
+    GESStatsNotFoundError,
+    GraphExecutionStats,
+    GraphExecutionStatsSummary,
+    InvocationStatsSummary,
+    ModelCacheStatsSummary,
+    NodeExecutionStats,
+    NodeExecutionStatsSummary,
+)
 
 # Size of 1GB in bytes.
 GB = 2**30
@@ -95,31 +105,66 @@ class InvocationStatsService(InvocationStatsServiceBase):
             del self._stats[graph_execution_state_id]
             del self._cache_stats[graph_execution_state_id]
         except KeyError as e:
-            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
+            msg = f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-    def log_stats(self, graph_execution_state_id: str):
+    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
+        graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
+        node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
+        model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
+        vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None
+
+        return InvocationStatsSummary(
+            graph_stats=graph_stats_summary,
+            model_cache_stats=model_cache_stats_summary,
+            node_stats=node_stats_summaries,
+            vram_usage_gb=vram_usage_gb,
+        )
+
+    def log_stats(self, graph_execution_state_id: str) -> None:
+        stats = self.get_stats(graph_execution_state_id)
+        logger.info(str(stats))
+
+    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
+        stats = self.get_stats(graph_execution_state_id)
+        with open(output_path, "w") as f:
+            f.write(json.dumps(stats.as_dict(), indent=2))
+
+    def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
         try:
-            graph_stats = self._stats[graph_execution_state_id]
             cache_stats = self._cache_stats[graph_execution_state_id]
         except KeyError as e:
-            logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
-            return
+            msg = f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-        log = graph_stats.get_pretty_log(graph_execution_state_id)
+        return ModelCacheStatsSummary(
+            cache_hits=cache_stats.hits,
+            cache_misses=cache_stats.misses,
+            high_water_mark_gb=cache_stats.high_watermark / GB,
+            cache_size_gb=cache_stats.cache_size / GB,
+            total_usage_gb=sum(list(cache_stats.loaded_model_sizes.values())) / GB,
+            models_cached=cache_stats.in_cache,
+            models_cleared=cache_stats.cleared,
+        )
 
-        hwm = cache_stats.high_watermark / GB
-        tot = cache_stats.cache_size / GB
-        loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
-        log += f"RAM used to load models: {loaded:4.2f}G\n"
-        if torch.cuda.is_available():
-            log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
-        log += "RAM cache statistics:\n"
-        log += f" Model cache hits: {cache_stats.hits}\n"
-        log += f" Model cache misses: {cache_stats.misses}\n"
-        log += f" Models cached: {cache_stats.in_cache}\n"
-        log += f" Models cleared from cache: {cache_stats.cleared}\n"
-        log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
-        logger.info(log)
+    def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
+        try:
+            graph_stats = self._stats[graph_execution_state_id]
+        except KeyError as e:
+            msg = f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-        del self._stats[graph_execution_state_id]
-        del self._cache_stats[graph_execution_state_id]
+        return graph_stats.get_graph_stats_summary(graph_execution_state_id)
+
+    def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
+        try:
+            graph_stats = self._stats[graph_execution_state_id]
+        except KeyError as e:
+            msg = f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
+
+        return graph_stats.get_node_stats_summaries()
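
Not part of the patch: a minimal usage sketch of the stats API introduced above, mirroring how the processor consumes it. The `stats` handle is assumed to be the `InvocationStatsService` instance exposed as `services.performance_statistics`; the function name is illustrative only.

from contextlib import suppress
from pathlib import Path

from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError


def report_session_stats(stats, graph_execution_state_id: str, output_path: Path) -> None:
    # Log the human-readable summary, write the same data as JSON, then clear
    # the accumulated entries for this session. Per this diff, these calls now
    # raise GESStatsNotFoundError (rather than warning and returning) when the
    # session was never tracked, so callers suppress it explicitly.
    with suppress(GESStatsNotFoundError):
        stats.log_stats(graph_execution_state_id)
        stats.dump_stats(graph_execution_state_id=graph_execution_state_id, output_path=output_path)
        stats.reset_stats(graph_execution_state_id)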