Refactor the invocation stats service for better readability and to support reporting the execution wall time.

Ryan Dick
2024-01-11 11:55:16 -05:00
committed by Kent Keirsey
parent c000e270a0
commit c8929b35f0
3 changed files with 122 additions and 166 deletions
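
Note: the headline change is the new "TOTAL GRAPH WALL TIME" line in the stats log. A minimal sketch of the distinction (not part of the commit; it only restates the arithmetic that GraphExecutionStats.get_pretty_log performs in the diff below, with node_stats standing in for the recorded NodeExecutionStats list):

    # Summed node time counts only time spent inside nodes; wall time also includes gaps between nodes.
    total_execution_time = sum(n.end_time - n.start_time for n in node_stats)                # TOTAL GRAPH EXECUTION TIME
    wall_time = max(n.end_time for n in node_stats) - min(n.start_time for n in node_stats)  # TOTAL GRAPH WALL TIME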

View File: invocation_stats_base.py

@@ -30,23 +30,13 @@ writes to the system log is stored in InvocationServices.performance_statistics.
 from abc import ABC, abstractmethod
 from contextlib import AbstractContextManager
-from typing import Dict
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
-from invokeai.backend.model_management.model_cache import CacheStats
-
-from .invocation_stats_common import NodeLog
 
 
 class InvocationStatsServiceBase(ABC):
     "Abstract base class for recording node memory/time performance statistics"
 
-    # {graph_id => NodeLog}
-    _stats: Dict[str, NodeLog]
-    _cache_stats: Dict[str, CacheStats]
-    ram_used: float
-    ram_changed: float
-
     @abstractmethod
     def __init__(self):
         """
@@ -76,46 +66,9 @@ class InvocationStatsServiceBase(ABC):
         """
         pass
 
-    @abstractmethod
-    def reset_all_stats(self):
-        """Zero all statistics"""
-        pass
-
-    @abstractmethod
-    def update_invocation_stats(
-        self,
-        graph_id: str,
-        invocation_type: str,
-        time_used: float,
-        vram_used: float,
-    ):
-        """
-        Add timing information on execution of a node. Usually
-        used internally.
-        :param graph_id: ID of the graph that is currently executing
-        :param invocation_type: String literal type of the node
-        :param time_used: Time used by node's exection (sec)
-        :param vram_used: Maximum VRAM used during exection (GB)
-        """
-        pass
-
     @abstractmethod
     def log_stats(self):
         """
         Write out the accumulated statistics to the log or somewhere else.
         """
         pass
-
-    @abstractmethod
-    def update_mem_stats(
-        self,
-        ram_used: float,
-        ram_changed: float,
-    ):
-        """
-        Update the collector with RAM memory usage info.
-        :param ram_used: How much RAM is currently in use.
-        :param ram_changed: How much RAM changed since last generation.
-        """
-        pass
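
Note: the update_mem_stats() and update_invocation_stats() hooks can be dropped from the interface because, as the default implementation below shows, a node's stats are now assembled in one place when it finishes. A hedged before/after sketch of the recording pattern (arguments abbreviated with ...):

    # Before: the StatsContext pushed partial results back through collector callbacks.
    collector.update_mem_stats(ram_used=..., ram_changed=...)
    collector.update_invocation_stats(graph_id=..., invocation_type=..., time_used=..., vram_used=...)

    # After: one NodeExecutionStats record per node execution, appended to the graph's stats.
    graph_stats.add_node_execution_stats(
        NodeExecutionStats(invocation_type=..., start_time=..., end_time=..., start_ram_gb=..., end_ram_gb=..., peak_vram_gb=...)
    )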

View File: invocation_stats_common.py

@@ -1,25 +1,87 @@
-from dataclasses import dataclass, field
-from typing import Dict
+from collections import defaultdict
+from dataclasses import dataclass
 
 # size of GIG in bytes
 GIG = 1073741824
 
 
 @dataclass
-class NodeStats:
-    """Class for tracking execution stats of an invocation node"""
-
-    calls: int = 0
-    time_used: float = 0.0  # seconds
-    max_vram: float = 0.0  # GB
-    cache_hits: int = 0
-    cache_misses: int = 0
-    cache_high_watermark: int = 0
-
-
-@dataclass
-class NodeLog:
-    """Class for tracking node usage"""
-
-    # {node_type => NodeStats}
-    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+class NodeExecutionStats:
+    """Class for tracking execution stats of an invocation node."""
+
+    invocation_type: str
+
+    start_time: float  # Seconds since the epoch.
+    end_time: float  # Seconds since the epoch.
+
+    start_ram_gb: float  # GB
+    end_ram_gb: float  # GB
+    peak_vram_gb: float  # GB
+
+    def total_time(self) -> float:
+        return self.end_time - self.start_time
+
+
+class GraphExecutionStats:
+    """Class for tracking execution stats of a graph."""
+
+    def __init__(self):
+        self._node_stats_list: list[NodeExecutionStats] = []
+
+    def add_node_execution_stats(self, node_stats: NodeExecutionStats):
+        self._node_stats_list.append(node_stats)
+
+    def get_total_run_time(self) -> float:
+        """Get the total time spent executing nodes in the graph."""
+        total = 0.0
+        for node_stats in self._node_stats_list:
+            total += node_stats.total_time()
+        return total
+
+    def get_first_node_stats(self) -> NodeExecutionStats | None:
+        """Get the stats of the first node in the graph (by start_time)."""
+        first_node = None
+        for node_stats in self._node_stats_list:
+            if first_node is None or node_stats.start_time < first_node.start_time:
+                first_node = node_stats
+
+        assert first_node is not None
+        return first_node
+
+    def get_last_node_stats(self) -> NodeExecutionStats | None:
+        """Get the stats of the last node in the graph (by end_time)."""
+        last_node = None
+        for node_stats in self._node_stats_list:
+            if last_node is None or node_stats.end_time > last_node.end_time:
+                last_node = node_stats
+
+        return last_node
+
+    def get_pretty_log(self, graph_execution_state_id: str) -> str:
+        log = f"Graph stats: {graph_execution_state_id}\n"
+        log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
+
+        # Log stats aggregated by node type.
+        node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
+        for node_stats in self._node_stats_list:
+            node_stats_by_type[node_stats.invocation_type].append(node_stats)
+
+        for node_type, node_type_stats_list in node_stats_by_type.items():
+            num_calls = len(node_type_stats_list)
+            time_used = sum([n.total_time() for n in node_type_stats_list])
+            peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
+            log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
+
+        # Log stats for the entire graph.
+        log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
+
+        first_node = self.get_first_node_stats()
+        last_node = self.get_last_node_stats()
+        if first_node is not None and last_node is not None:
+            total_wall_time = last_node.end_time - first_node.start_time
+            ram_change = last_node.end_ram_gb - first_node.start_ram_gb
+            log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
+            log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
+
+        return log
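
Note: a short usage sketch of the new classes (illustrative only; the node types and timings are made up):

    import time

    stats = GraphExecutionStats()
    now = time.time()
    stats.add_node_execution_stats(
        NodeExecutionStats(
            invocation_type="main_model_loader",
            start_time=now,
            end_time=now + 0.5,
            start_ram_gb=6.0,
            end_ram_gb=8.0,
            peak_vram_gb=0.0,
        )
    )
    stats.add_node_execution_stats(
        NodeExecutionStats(
            invocation_type="denoise_latents",
            start_time=now + 0.7,  # 0.2 s idle gap, so wall time exceeds summed node time
            end_time=now + 3.7,
            start_ram_gb=8.0,
            end_ram_gb=8.1,
            peak_vram_gb=4.2,
        )
    )
    print(stats.get_total_run_time())            # 3.5 -- seconds spent inside nodes
    print(stats.get_pretty_log("graph-id-123"))  # per-node table, total/wall time, RAM change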

View File: invocation_stats_default.py

@@ -1,5 +1,5 @@
 import time
-from typing import Dict
+from contextlib import contextmanager
 
 import psutil
 import torch
@@ -7,85 +7,54 @@ import torch
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
 from invokeai.backend.model_management.model_cache import CacheStats
 
 from .invocation_stats_base import InvocationStatsServiceBase
-from .invocation_stats_common import GIG, NodeLog, NodeStats
+from .invocation_stats_common import GIG, GraphExecutionStats, NodeExecutionStats
 
 
 class InvocationStatsService(InvocationStatsServiceBase):
     """Accumulate performance information about a running graph. Collects time spent in each node,
     as well as the maximum and current VRAM utilisation for CUDA systems"""
 
-    _invoker: Invoker
-
     def __init__(self):
-        # {graph_id => NodeLog}
-        self._stats: Dict[str, NodeLog] = {}
-        self._cache_stats: Dict[str, CacheStats] = {}
-        self.ram_used: float = 0.0
-        self.ram_changed: float = 0.0
+        # Maps graph_execution_state_id to GraphExecutionStats.
+        self._stats: dict[str, GraphExecutionStats] = {}
+        # Maps graph_execution_state_id to model manager CacheStats.
+        self._cache_stats: dict[str, CacheStats] = {}
 
     def start(self, invoker: Invoker) -> None:
         self._invoker = invoker
 
-    class StatsContext:
-        """Context manager for collecting statistics."""
-
-        invocation: BaseInvocation
-        collector: "InvocationStatsServiceBase"
-        graph_id: str
-        start_time: float
-        ram_used: int
-        model_manager: ModelManagerServiceBase
-
-        def __init__(
-            self,
-            invocation: BaseInvocation,
-            graph_id: str,
-            model_manager: ModelManagerServiceBase,
-            collector: "InvocationStatsServiceBase",
-        ):
-            """Initialize statistics for this run."""
-            self.invocation = invocation
-            self.collector = collector
-            self.graph_id = graph_id
-            self.start_time = 0.0
-            self.ram_used = 0
-            self.model_manager = model_manager
-
-        def __enter__(self):
-            self.start_time = time.time()
-            if torch.cuda.is_available():
-                torch.cuda.reset_peak_memory_stats()
-            self.ram_used = psutil.Process().memory_info().rss
-            if self.model_manager:
-                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
-
-        def __exit__(self, *args):
-            """Called on exit from the context."""
-            ram_used = psutil.Process().memory_info().rss
-            self.collector.update_mem_stats(
-                ram_used=ram_used / GIG,
-                ram_changed=(ram_used - self.ram_used) / GIG,
-            )
-            self.collector.update_invocation_stats(
-                graph_id=self.graph_id,
-                invocation_type=self.invocation.type,  # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
-                time_used=time.time() - self.start_time,
-                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
-            )
-
-    def collect_stats(
-        self,
-        invocation: BaseInvocation,
-        graph_execution_state_id: str,
-    ) -> StatsContext:
-        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
-            self._stats[graph_execution_state_id] = NodeLog()
+    @contextmanager
+    def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
+        if not self._stats.get(graph_execution_state_id):
+            # First time we're seeing this graph_execution_state_id.
+            self._stats[graph_execution_state_id] = GraphExecutionStats()
             self._cache_stats[graph_execution_state_id] = CacheStats()
-        return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
+
+        # Record state before the invocation.
+        start_time = time.time()
+        start_ram = psutil.Process().memory_info().rss
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+        if self._invoker.services.model_manager:
+            self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
+
+        try:
+            # Let the invocation run.
+            yield None
+        finally:
+            # Record state after the invocation.
+            node_stats = NodeExecutionStats(
+                invocation_type=invocation.type,
+                start_time=start_time,
+                end_time=time.time(),
+                start_ram_gb=start_ram / GIG,
+                end_ram_gb=psutil.Process().memory_info().rss / GIG,
+                peak_vram_gb=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
+            )
+            self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
     def reset_all_stats(self):
         """Zero all statistics"""
@@ -97,28 +66,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
         except KeyError:
             logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
 
-    def update_mem_stats(
-        self,
-        ram_used: float,
-        ram_changed: float,
-    ):
-        self.ram_used = ram_used
-        self.ram_changed = ram_changed
-
-    def update_invocation_stats(
-        self,
-        graph_id: str,
-        invocation_type: str,
-        time_used: float,
-        vram_used: float,
-    ):
-        if not self._stats[graph_id].nodes.get(invocation_type):
-            self._stats[graph_id].nodes[invocation_type] = NodeStats()
-        stats = self._stats[graph_id].nodes[invocation_type]
-        stats.calls += 1
-        stats.time_used += time_used
-        stats.max_vram = max(stats.max_vram, vram_used)
-
     def log_stats(self):
         completed = set()
         errored = set()
@@ -132,29 +79,23 @@ class InvocationStatsService(InvocationStatsServiceBase):
             if not current_graph_state.is_complete():
                 continue
 
-            total_time = 0
-            logger.info(f"Graph stats: {graph_id}")
-            logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
-            for node_type, stats in self._stats[graph_id].nodes.items():
-                logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
-                total_time += stats.time_used
+            graph_stats = self._stats[graph_id]
+            log = graph_stats.get_pretty_log(graph_id)
 
             cache_stats = self._cache_stats[graph_id]
             hwm = cache_stats.high_watermark / GIG
             tot = cache_stats.cache_size / GIG
             loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
-
-            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
-            logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
-            logger.info(f"RAM used to load models: {loaded:4.2f}G")
+            log += f"RAM used to load models: {loaded:4.2f}G\n"
             if torch.cuda.is_available():
-                logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
-            logger.info("RAM cache statistics:")
-            logger.info(f"  Model cache hits: {cache_stats.hits}")
-            logger.info(f"  Model cache misses: {cache_stats.misses}")
-            logger.info(f"  Models cached: {cache_stats.in_cache}")
-            logger.info(f"  Models cleared from cache: {cache_stats.cleared}")
-            logger.info(f"  Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
+                log += f"VRAM in use: {(torch.cuda.memory_allocated() / GIG):4.3f}G\n"
+            log += "RAM cache statistics:\n"
+            log += f"  Model cache hits: {cache_stats.hits}\n"
+            log += f"  Model cache misses: {cache_stats.misses}\n"
+            log += f"  Models cached: {cache_stats.in_cache}\n"
+            log += f"  Models cleared from cache: {cache_stats.cleared}\n"
+            log += f"  Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
+            logger.info(log)
 
             completed.add(graph_id)
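
Note: the call site is not part of this commit; a hedged sketch of how the graph processor would be expected to use the new context manager (names outside the stats service are assumptions made for illustration):

    # `stats` is the InvocationStatsService instance; `invocation`, `graph_execution_state`, and
    # `context` come from the graph processor and are assumed here.
    with stats.collect_stats(invocation, graph_execution_state.id):
        outputs = invocation.invoke(context)  # stats are recorded in the finally block even if this raises

    stats.log_stats()  # once graphs are complete, writes the per-graph summary built by get_pretty_log()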