mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cd9baf8092
- move docstrings to ABC - `start_time: int` -> `start_time: float` - remove class attribute assignments in `StatsContext` - add `update_mem_stats()` to ABC - add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring
302 lines
11 KiB
Python
302 lines
11 KiB
Python
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
|
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
|
|
|
Usage:
|
|
|
|
statistics = InvocationStatsService(graph_execution_manager)
|
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
|
... execute graphs...
|
|
statistics.log_stats()
|
|
|
|
Typical output:
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
|
|
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
|
"""
|
|
|
|
import psutil
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import AbstractContextManager
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict
|
|
from pydantic import ValidationError
|
|
|
|
import torch
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
|
|
from ..invocations.baseinvocation import BaseInvocation
|
|
from .graph import GraphExecutionState
|
|
from .item_storage import ItemStorageABC
|
|
from .model_manager_service import ModelManagerService
|
|
from invokeai.backend.model_management.model_cache import CacheStats
|
|
|
|
# size of GIG in bytes
|
|
GIG = 1073741824
|
|
|
|
|
|
@dataclass
|
|
class NodeStats:
|
|
"""Class for tracking execution stats of an invocation node"""
|
|
|
|
calls: int = 0
|
|
time_used: float = 0.0 # seconds
|
|
max_vram: float = 0.0 # GB
|
|
cache_hits: int = 0
|
|
cache_misses: int = 0
|
|
cache_high_watermark: int = 0
|
|
|
|
|
|
@dataclass
|
|
class NodeLog:
|
|
"""Class for tracking node usage"""
|
|
|
|
# {node_type => NodeStats}
|
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
|
|
|
|
class InvocationStatsServiceBase(ABC):
|
|
"Abstract base class for recording node memory/time performance statistics"
|
|
|
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
|
# {graph_id => NodeLog}
|
|
_stats: Dict[str, NodeLog]
|
|
_cache_stats: Dict[str, CacheStats]
|
|
ram_used: float
|
|
ram_changed: float
|
|
|
|
@abstractmethod
|
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
|
"""
|
|
Initialize the InvocationStatsService and reset counters to zero
|
|
:param graph_execution_manager: Graph execution manager for this session
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def collect_stats(
|
|
self,
|
|
invocation: BaseInvocation,
|
|
graph_execution_state_id: str,
|
|
) -> AbstractContextManager:
|
|
"""
|
|
Return a context object that will capture the statistics on the execution
|
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
|
:param invocation: BaseInvocation object from the current graph.
|
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def reset_stats(self, graph_execution_state_id: str):
|
|
"""
|
|
Reset all statistics for the indicated graph
|
|
:param graph_execution_state_id
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def reset_all_stats(self):
|
|
"""Zero all statistics"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def update_invocation_stats(
|
|
self,
|
|
graph_id: str,
|
|
invocation_type: str,
|
|
time_used: float,
|
|
vram_used: float,
|
|
):
|
|
"""
|
|
Add timing information on execution of a node. Usually
|
|
used internally.
|
|
:param graph_id: ID of the graph that is currently executing
|
|
:param invocation_type: String literal type of the node
|
|
:param time_used: Time used by node's exection (sec)
|
|
:param vram_used: Maximum VRAM used during exection (GB)
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def log_stats(self):
|
|
"""
|
|
Write out the accumulated statistics to the log or somewhere else.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def update_mem_stats(
|
|
self,
|
|
ram_used: float,
|
|
ram_changed: float,
|
|
):
|
|
"""
|
|
Update the collector with RAM memory usage info.
|
|
|
|
:param ram_used: How much RAM is currently in use.
|
|
:param ram_changed: How much RAM changed since last generation.
|
|
"""
|
|
pass
|
|
|
|
|
|
class InvocationStatsService(InvocationStatsServiceBase):
|
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
|
|
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
|
self.graph_execution_manager = graph_execution_manager
|
|
# {graph_id => NodeLog}
|
|
self._stats: Dict[str, NodeLog] = {}
|
|
self._cache_stats: Dict[str, CacheStats] = {}
|
|
self.ram_used: float = 0.0
|
|
self.ram_changed: float = 0.0
|
|
|
|
class StatsContext:
|
|
"""Context manager for collecting statistics."""
|
|
|
|
invocation: BaseInvocation
|
|
collector: "InvocationStatsServiceBase"
|
|
graph_id: str
|
|
start_time: float
|
|
ram_used: int
|
|
model_manager: ModelManagerService
|
|
|
|
def __init__(
|
|
self,
|
|
invocation: BaseInvocation,
|
|
graph_id: str,
|
|
model_manager: ModelManagerService,
|
|
collector: "InvocationStatsServiceBase",
|
|
):
|
|
"""Initialize statistics for this run."""
|
|
self.invocation = invocation
|
|
self.collector = collector
|
|
self.graph_id = graph_id
|
|
self.start_time = 0.0
|
|
self.ram_used = 0
|
|
self.model_manager = model_manager
|
|
|
|
def __enter__(self):
|
|
self.start_time = time.time()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.reset_peak_memory_stats()
|
|
self.ram_used = psutil.Process().memory_info().rss
|
|
if self.model_manager:
|
|
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
|
|
|
def __exit__(self, *args):
|
|
"""Called on exit from the context."""
|
|
ram_used = psutil.Process().memory_info().rss
|
|
self.collector.update_mem_stats(
|
|
ram_used=ram_used / GIG,
|
|
ram_changed=(ram_used - self.ram_used) / GIG,
|
|
)
|
|
self.collector.update_invocation_stats(
|
|
graph_id=self.graph_id,
|
|
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
|
time_used=time.time() - self.start_time,
|
|
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
|
)
|
|
|
|
def collect_stats(
|
|
self,
|
|
invocation: BaseInvocation,
|
|
graph_execution_state_id: str,
|
|
model_manager: ModelManagerService,
|
|
) -> StatsContext:
|
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
|
self._stats[graph_execution_state_id] = NodeLog()
|
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
|
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
|
|
|
def reset_all_stats(self):
|
|
"""Zero all statistics"""
|
|
self._stats = {}
|
|
|
|
def reset_stats(self, graph_execution_id: str):
|
|
try:
|
|
self._stats.pop(graph_execution_id)
|
|
except KeyError:
|
|
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
|
|
|
def update_mem_stats(
|
|
self,
|
|
ram_used: float,
|
|
ram_changed: float,
|
|
):
|
|
self.ram_used = ram_used
|
|
self.ram_changed = ram_changed
|
|
|
|
def update_invocation_stats(
|
|
self,
|
|
graph_id: str,
|
|
invocation_type: str,
|
|
time_used: float,
|
|
vram_used: float,
|
|
):
|
|
if not self._stats[graph_id].nodes.get(invocation_type):
|
|
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
|
stats = self._stats[graph_id].nodes[invocation_type]
|
|
stats.calls += 1
|
|
stats.time_used += time_used
|
|
stats.max_vram = max(stats.max_vram, vram_used)
|
|
|
|
def log_stats(self):
|
|
completed = set()
|
|
for graph_id, node_log in self._stats.items():
|
|
try:
|
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
|
except ValidationError:
|
|
del self._stats[graph_id]
|
|
del self._cache_stats[graph_id]
|
|
continue
|
|
|
|
if not current_graph_state.is_complete():
|
|
continue
|
|
|
|
total_time = 0
|
|
logger.info(f"Graph stats: {graph_id}")
|
|
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
|
|
for node_type, stats in self._stats[graph_id].nodes.items():
|
|
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
|
|
total_time += stats.time_used
|
|
|
|
cache_stats = self._cache_stats[graph_id]
|
|
hwm = cache_stats.high_watermark / GIG
|
|
tot = cache_stats.cache_size / GIG
|
|
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
|
|
|
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
|
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
|
logger.info(f"RAM used to load models: {loaded:4.2f}G")
|
|
if torch.cuda.is_available():
|
|
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
|
|
logger.info("RAM cache statistics:")
|
|
logger.info(f" Model cache hits: {cache_stats.hits}")
|
|
logger.info(f" Model cache misses: {cache_stats.misses}")
|
|
logger.info(f" Models cached: {cache_stats.in_cache}")
|
|
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
|
|
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
|
|
|
|
completed.add(graph_id)
|
|
|
|
for graph_id in completed:
|
|
del self._stats[graph_id]
|
|
del self._cache_stats[graph_id]
|