feat(stats): refactor InvocationStatsService to output stats as dataclasses

This allows the stats to be written to disk as JSON and analyzed.

- Add dataclasses to hold stats.
- Move stats pretty-print logic to `__str__` of the new `InvocationStatsSummary` class.
- Add `get_stats` and `dump_stats` methods to `InvocationStatsServiceBase`.
- `InvocationStatsService` now raises `GESStatsNotFoundError` if stats are requested for a session it doesn't know about. This avoids needing a lot of messy null checks.
- Update `DefaultInvocationProcessor` to use the new stats methods and suppress the new errors.
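
A minimal sketch of how the new surface fits together (names like `invoker` and `session_id` are illustrative, not part of this diff):

    from pathlib import Path

    stats_service = invoker.services.performance_statistics  # InvocationStatsService
    summary = stats_service.get_stats(session_id)  # returns an InvocationStatsSummary dataclass
    print(summary)  # __str__ renders the human-readable stats table
    stats_service.dump_stats(session_id, output_path=Path("stats.json"))  # same data as JSON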
psychedelicious 2024-01-31 21:31:38 +11:00
parent 25291a2e01
commit b24e8dd829
4 changed files with 233 additions and 53 deletions

File 1/4: the default invocation processor (DefaultInvocationProcessor)

@@ -1,11 +1,15 @@
 import time
 import traceback
+from contextlib import suppress
 from threading import BoundedSemaphore, Event, Thread
 from typing import Optional
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
+from invokeai.app.services.invocation_stats.invocation_stats_common import (
+    GESStatsNotFoundError,
+)
 from invokeai.app.util.profiler import Profiler
 
 from ..invoker import Invoker
@@ -152,7 +156,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                        pass
 
                    except CanceledException:
-                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
+                        with suppress(GESStatsNotFoundError):
+                            self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
                        pass
 
                    except Exception as e:
@@ -177,7 +182,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            error_type=e.__class__.__name__,
                            error=error,
                        )
-                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
+                        with suppress(GESStatsNotFoundError):
+                            self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
                        pass
 
                    # Check queue to see if this is canceled, and skip if so
@@ -209,15 +215,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            error=traceback.format_exc(),
                        )
                    elif is_complete:
-                        self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
+                        with suppress(GESStatsNotFoundError):
+                            self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
                        self.__invoker.services.events.emit_graph_execution_complete(
                            queue_batch_id=queue_item.session_queue_batch_id,
                            queue_item_id=queue_item.session_queue_item_id,
                            queue_id=queue_item.session_queue_id,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                        if profiler:
-                            profiler.stop()
+                            profile_path = profiler.stop()
+                            stats_path = profile_path.with_suffix(".json")
+                            self.__invoker.services.performance_statistics.dump_stats(
+                                graph_execution_state_id=graph_execution_state.id, output_path=stats_path
+                            )
+                        self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
 
            except KeyboardInterrupt:
                pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
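
For reference, `contextlib.suppress` from the stdlib is shorthand for a try/except that swallows the named exception, so each suppressed call above is equivalent to this sketch (`performance_statistics` and `graph_execution_state` as in the diff):

    from contextlib import suppress

    # The new form used in the processor:
    with suppress(GESStatsNotFoundError):
        performance_statistics.reset_stats(graph_execution_state.id)

    # ...which behaves like the explicit version:
    try:
        performance_statistics.reset_stats(graph_execution_state.id)
    except GESStatsNotFoundError:
        pass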

File 2/4: invokeai/app/services/invocation_stats/invocation_stats_base.py

@@ -30,8 +30,10 @@ writes to the system log is stored in InvocationServices.performance_statistics.
 
 from abc import ABC, abstractmethod
 from contextlib import AbstractContextManager
+from pathlib import Path
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
+from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
 
 
 class InvocationStatsServiceBase(ABC):
@@ -61,8 +63,9 @@
     @abstractmethod
     def reset_stats(self, graph_execution_state_id: str):
         """
-        Reset all statistics for the indicated graph
-        :param graph_execution_state_id
+        Reset all statistics for the indicated graph.
+        :param graph_execution_state_id: The id of the session whose stats to reset.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
         """
         pass
 
@@ -70,5 +73,26 @@
     def log_stats(self, graph_execution_state_id: str):
         """
         Write out the accumulated statistics to the log or somewhere else.
+        :param graph_execution_state_id: The id of the session whose stats to log.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
         """
         pass
+
+    @abstractmethod
+    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
+        """
+        Gets the accumulated statistics for the indicated graph.
+        :param graph_execution_state_id: The id of the session whose stats to get.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
+        """
+        pass
+
+    @abstractmethod
+    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
+        """
+        Write out the accumulated statistics to the indicated path as JSON.
+        :param graph_execution_state_id: The id of the session whose stats to dump.
+        :param output_path: The file to write the stats to.
+        :raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
+        """
+        pass

File 3/4: invokeai/app/services/invocation_stats/invocation_stats_common.py

@@ -1,5 +1,91 @@
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
+
+
+class GESStatsNotFoundError(Exception):
+    """Raised when execution stats are not found for a given Graph Execution State."""
+
+
+@dataclass
+class NodeExecutionStatsSummary:
+    """The stats for a specific type of node."""
+
+    node_type: str
+    num_calls: int
+    time_used_seconds: float
+    peak_vram_gb: float
+
+
+@dataclass
+class ModelCacheStatsSummary:
+    """The stats for the model cache."""
+
+    high_water_mark_gb: float
+    cache_size_gb: float
+    total_usage_gb: float
+    cache_hits: int
+    cache_misses: int
+    models_cached: int
+    models_cleared: int
+
+
+@dataclass
+class GraphExecutionStatsSummary:
+    """The stats for the graph execution state."""
+
+    graph_execution_state_id: str
+    execution_time_seconds: float
+    # `wall_time_seconds`, `ram_usage_gb` and `ram_change_gb` are derived from the node execution stats.
+    # In some situations, there are no node stats, so these values are optional.
+    wall_time_seconds: Optional[float]
+    ram_usage_gb: Optional[float]
+    ram_change_gb: Optional[float]
+
+
+@dataclass
+class InvocationStatsSummary:
+    """
+    The accumulated stats for a graph execution.
+    Its `__str__` method returns a human-readable stats summary.
+    """
+
+    vram_usage_gb: Optional[float]
+    graph_stats: GraphExecutionStatsSummary
+    model_cache_stats: ModelCacheStatsSummary
+    node_stats: list[NodeExecutionStatsSummary]
+
+    def __str__(self) -> str:
+        _str = ""
+        _str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
+        _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
+
+        for summary in self.node_stats:
+            _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
+
+        _str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"
+
+        if self.graph_stats.wall_time_seconds is not None:
+            _str += f"TOTAL GRAPH WALL TIME: {self.graph_stats.wall_time_seconds:7.3f}s\n"
+
+        if self.graph_stats.ram_usage_gb is not None and self.graph_stats.ram_change_gb is not None:
+            _str += f"RAM used by InvokeAI process: {self.graph_stats.ram_usage_gb:4.2f}G ({self.graph_stats.ram_change_gb:+5.3f}G)\n"
+
+        _str += f"RAM used to load models: {self.model_cache_stats.total_usage_gb:4.2f}G\n"
+        if self.vram_usage_gb:
+            _str += f"VRAM in use: {self.vram_usage_gb:4.3f}G\n"
+        _str += "RAM cache statistics:\n"
+        _str += f"   Model cache hits: {self.model_cache_stats.cache_hits}\n"
+        _str += f"   Model cache misses: {self.model_cache_stats.cache_misses}\n"
+        _str += f"   Models cached: {self.model_cache_stats.models_cached}\n"
+        _str += f"   Models cleared from cache: {self.model_cache_stats.models_cleared}\n"
+        _str += f"   Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}/{self.model_cache_stats.cache_size_gb:4.2f}G\n"
+
+        return _str
+
+    def as_dict(self) -> dict[str, Any]:
+        """Returns the stats as a dictionary."""
+        return asdict(self)
 
 
 @dataclass
@@ -55,12 +141,33 @@ class GraphExecutionStats:
         return last_node
 
-    def get_pretty_log(self, graph_execution_state_id: str) -> str:
-        log = f"Graph stats: {graph_execution_state_id}\n"
-        log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
-
-        # Log stats aggregated by node type.
+    def get_graph_stats_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
+        """Get a summary of the graph stats."""
+        first_node = self.get_first_node_stats()
+        last_node = self.get_last_node_stats()
+
+        wall_time_seconds: Optional[float] = None
+        ram_usage_gb: Optional[float] = None
+        ram_change_gb: Optional[float] = None
+        if last_node and first_node:
+            wall_time_seconds = last_node.end_time - first_node.start_time
+            ram_usage_gb = last_node.end_ram_gb
+            ram_change_gb = last_node.end_ram_gb - first_node.start_ram_gb
+
+        return GraphExecutionStatsSummary(
+            graph_execution_state_id=graph_execution_state_id,
+            execution_time_seconds=self.get_total_run_time(),
+            wall_time_seconds=wall_time_seconds,
+            ram_usage_gb=ram_usage_gb,
+            ram_change_gb=ram_change_gb,
+        )
+
+    def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
+        """Get a summary of the node stats."""
+        summaries: list[NodeExecutionStatsSummary] = []
         node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
         for node_stats in self._node_stats_list:
             node_stats_by_type[node_stats.invocation_type].append(node_stats)
@@ -68,17 +175,9 @@ class GraphExecutionStats:
             num_calls = len(node_type_stats_list)
             time_used = sum([n.total_time() for n in node_type_stats_list])
             peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
-            log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
-
-        # Log stats for the entire graph.
-        log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
-
-        first_node = self.get_first_node_stats()
-        last_node = self.get_last_node_stats()
-        if first_node is not None and last_node is not None:
-            total_wall_time = last_node.end_time - first_node.start_time
-            ram_change = last_node.end_ram_gb - first_node.start_ram_gb
-            log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
-            log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
-
-        return log
+            summary = NodeExecutionStatsSummary(
+                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
+            )
+            summaries.append(summary)
+
+        return summaries
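
To illustrate how the summary dataclasses compose, here is a hand-built example (values are made up) showing that `as_dict()` recurses through the nested dataclasses via `dataclasses.asdict`:

    from invokeai.app.services.invocation_stats.invocation_stats_common import (
        GraphExecutionStatsSummary,
        InvocationStatsSummary,
        ModelCacheStatsSummary,
        NodeExecutionStatsSummary,
    )

    summary = InvocationStatsSummary(
        vram_usage_gb=None,  # None when CUDA is unavailable
        graph_stats=GraphExecutionStatsSummary(
            graph_execution_state_id="some-session-id",
            execution_time_seconds=4.2,
            wall_time_seconds=4.5,  # Optional: None when there are no node stats
            ram_usage_gb=6.1,
            ram_change_gb=0.3,
        ),
        model_cache_stats=ModelCacheStatsSummary(
            high_water_mark_gb=4.0,
            cache_size_gb=6.0,
            total_usage_gb=2.5,
            cache_hits=3,
            cache_misses=1,
            models_cached=2,
            models_cleared=0,
        ),
        node_stats=[
            NodeExecutionStatsSummary(node_type="compel", num_calls=2, time_used_seconds=0.3, peak_vram_gb=0.2),
        ],
    )

    print(summary)            # human-readable table via __str__
    print(summary.as_dict())  # plain nested dict, ready for json.dumps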

File 4/4: the invocation stats service implementation (InvocationStatsService)

@@ -1,5 +1,7 @@
+import json
 import time
 from contextlib import contextmanager
+from pathlib import Path
 
 import psutil
 import torch
@@ -10,7 +12,15 @@ from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_management.model_cache import CacheStats
 
 from .invocation_stats_base import InvocationStatsServiceBase
-from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
+from .invocation_stats_common import (
+    GESStatsNotFoundError,
+    GraphExecutionStats,
+    GraphExecutionStatsSummary,
+    InvocationStatsSummary,
+    ModelCacheStatsSummary,
+    NodeExecutionStats,
+    NodeExecutionStatsSummary,
+)
 
 # Size of 1GB in bytes.
 GB = 2**30
@@ -95,31 +105,66 @@ class InvocationStatsService(InvocationStatsServiceBase):
             del self._stats[graph_execution_state_id]
             del self._cache_stats[graph_execution_state_id]
         except KeyError as e:
-            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
+            msg = f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-    def log_stats(self, graph_execution_state_id: str):
+    def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
+        graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
+        node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
+        model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
+        vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None
+
+        return InvocationStatsSummary(
+            graph_stats=graph_stats_summary,
+            model_cache_stats=model_cache_stats_summary,
+            node_stats=node_stats_summaries,
+            vram_usage_gb=vram_usage_gb,
+        )
+
+    def log_stats(self, graph_execution_state_id: str) -> None:
+        stats = self.get_stats(graph_execution_state_id)
+        logger.info(str(stats))
+
+    def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
+        stats = self.get_stats(graph_execution_state_id)
+        with open(output_path, "w") as f:
+            f.write(json.dumps(stats.as_dict(), indent=2))
+
+    def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
         try:
-            graph_stats = self._stats[graph_execution_state_id]
             cache_stats = self._cache_stats[graph_execution_state_id]
         except KeyError as e:
-            logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
-            return
+            msg = f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-        log = graph_stats.get_pretty_log(graph_execution_state_id)
+        return ModelCacheStatsSummary(
+            cache_hits=cache_stats.hits,
+            cache_misses=cache_stats.misses,
+            high_water_mark_gb=cache_stats.high_watermark / GB,
+            cache_size_gb=cache_stats.cache_size / GB,
+            total_usage_gb=sum(list(cache_stats.loaded_model_sizes.values())) / GB,
+            models_cached=cache_stats.in_cache,
+            models_cleared=cache_stats.cleared,
+        )
 
-        hwm = cache_stats.high_watermark / GB
-        tot = cache_stats.cache_size / GB
-        loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
-        log += f"RAM used to load models: {loaded:4.2f}G\n"
-        if torch.cuda.is_available():
-            log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
-        log += "RAM cache statistics:\n"
-        log += f"   Model cache hits: {cache_stats.hits}\n"
-        log += f"   Model cache misses: {cache_stats.misses}\n"
-        log += f"   Models cached: {cache_stats.in_cache}\n"
-        log += f"   Models cleared from cache: {cache_stats.cleared}\n"
-        log += f"   Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
-        logger.info(log)
+    def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
+        try:
+            graph_stats = self._stats[graph_execution_state_id]
+        except KeyError as e:
+            msg = f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
 
-        del self._stats[graph_execution_state_id]
-        del self._cache_stats[graph_execution_state_id]
+        return graph_stats.get_graph_stats_summary(graph_execution_state_id)
+
+    def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
+        try:
+            graph_stats = self._stats[graph_execution_state_id]
+        except KeyError as e:
+            msg = f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
+            logger.warning(msg)
+            raise GESStatsNotFoundError(msg)
+
+        return graph_stats.get_node_stats_summaries()
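
Since `dump_stats` writes plain JSON whose keys mirror the dataclass fields, a dump can be loaded back and analyzed with nothing but the stdlib. The path below is hypothetical; in the processor, the stats land next to the profiler output with the suffix swapped to `.json`:

    import json
    from pathlib import Path

    stats_path = Path("profiles/some-session-id.json")  # hypothetical location
    with open(stats_path) as f:
        stats = json.load(f)

    # Top-level keys mirror InvocationStatsSummary's fields.
    print(stats["graph_stats"]["execution_time_seconds"])
    for node in stats["node_stats"]:
        print(node["node_type"], node["num_calls"], node["time_used_seconds"], node["peak_vram_gb"])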