mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
81385d7d35
When retrieving a graph, it is parsed through pydantic. It is possible that this graph is invalid, and an error is thrown. Handle this by deleting the failed graph from the stats if this occurs.
312 lines
12 KiB
Python
312 lines
12 KiB
Python
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
|
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
|
|
|
Usage:
|
|
|
|
statistics = InvocationStatsService(graph_execution_manager)
|
|
with statistics.collect_stats(invocation, graph_execution_state.id, model_manager):
|
|
... execute graphs...
|
|
statistics.log_stats()
|
|
|
|
Typical output:
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
|
|
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
|
"""
|
|
|
|
import psutil
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import AbstractContextManager
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict
|
|
from pydantic import ValidationError
|
|
|
|
import torch
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
|
|
from ..invocations.baseinvocation import BaseInvocation
|
|
from .graph import GraphExecutionState
|
|
from .item_storage import ItemStorageABC
|
|
from .model_manager_service import ModelManagerService
|
|
from invokeai.backend.model_management.model_cache import CacheStats
|
|
|
|
# size of GIG in bytes
|
|
GIG = 1024**3  # number of bytes in one gigabyte (2**30)
|
|
|
|
|
|
class InvocationStatsServiceBase(ABC):
    """Interface for services that record per-node memory and time performance statistics."""

    @abstractmethod
    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        """
        Set up the statistics service with all counters at zero.

        :param graph_execution_manager: Graph execution manager for this session
        """

    @abstractmethod
    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
    ) -> AbstractContextManager:
        """
        Return a context manager that captures statistics on the execution of
        an invocation. Wrap the code that executes the invocation in a
        ``with`` statement using the returned object.

        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState for the current session.
        """

    @abstractmethod
    def reset_stats(self, graph_execution_state_id: str):
        """
        Discard all statistics gathered for the indicated graph.

        :param graph_execution_state_id: ID of the graph whose statistics are reset.
        """

    @abstractmethod
    def reset_all_stats(self):
        """Discard the statistics gathered for every graph."""

    @abstractmethod
    def update_invocation_stats(
        self,
        graph_id: str,
        invocation_type: str,
        time_used: float,
        vram_used: float,
        ram_used: float,
        ram_changed: float,
    ):
        """
        Record the resources consumed by one execution of a node. Usually
        used internally.

        NOTE(review): InvocationStatsService in this module overrides this
        method without the ram_used/ram_changed parameters and receives the
        RAM figures through a separate update_mem_stats() call instead --
        confirm which signature is intended.

        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Time used by the node's execution (sec)
        :param vram_used: Maximum VRAM used during execution (GB)
        :param ram_used: Current RAM available (GB)
        :param ram_changed: Change in RAM usage over the course of the run (GB)
        """

    @abstractmethod
    def log_stats(self):
        """Write out the accumulated statistics to the log or somewhere else."""
|
|
|
|
|
|
@dataclass
class NodeStats:
    """Accumulated execution statistics for one type of invocation node."""

    # Number of times a node of this type has been executed.
    calls: int = 0
    # Total wall-clock time spent in nodes of this type, in seconds.
    time_used: float = 0.0
    # Largest peak VRAM figure reported for any single execution, in GB.
    max_vram: float = 0.0
    # Cache counters. Nothing in this module updates them -- presumably kept
    # for per-node model cache accounting; verify before relying on them.
    cache_hits: int = 0
    cache_misses: int = 0
    cache_high_watermark: int = 0
|
|
|
|
|
|
@dataclass
class NodeLog:
    """Per-graph record of the statistics gathered for each node type."""

    # Mapping of node type name to its accumulated NodeStats. A factory is
    # used so that every NodeLog starts with its own empty dict.
    nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
|
|
|
|
class InvocationStatsService(InvocationStatsServiceBase):
    """Accumulate performance information about running graphs.

    Collects the time spent in each node type and the peak VRAM utilisation
    for CUDA systems, together with the RAM used by the process and the
    model cache statistics of each graph.
    """

    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
        """
        Create a service with no recorded statistics.

        :param graph_execution_manager: Storage used to look up a graph's execution state by its ID.
        """
        self.graph_execution_manager = graph_execution_manager
        # {graph_id => NodeLog}
        self._stats: Dict[str, NodeLog] = {}
        # {graph_id => CacheStats}; entries are created and removed together with _stats
        self._cache_stats: Dict[str, CacheStats] = {}
        # RAM figures (GB) reported by the most recently finished StatsContext
        self.ram_used: float = 0.0
        self.ram_changed: float = 0.0

    class StatsContext:
        """Context manager that measures one invocation and reports the result to the collector."""

        invocation: BaseInvocation = None
        collector: "InvocationStatsService" = None
        graph_id: str = None
        start_time: float = 0.0
        ram_used: int = 0
        model_manager: ModelManagerService = None

        def __init__(
            self,
            invocation: BaseInvocation,
            graph_id: str,
            model_manager: ModelManagerService,
            collector: "InvocationStatsService",
        ):
            """
            Initialize statistics for this run.

            :param invocation: The invocation about to be executed.
            :param graph_id: ID of the graph the invocation belongs to.
            :param model_manager: Model manager asked to record cache statistics; may be falsy to skip this.
            :param collector: The statistics service that receives the measurements.
            """
            self.invocation = invocation
            self.collector = collector
            self.graph_id = graph_id
            self.start_time = 0.0
            self.ram_used = 0
            self.model_manager = model_manager

        def __enter__(self):
            """Start timing, reset the CUDA peak-memory counter and note the current process RSS."""
            self.start_time = time.time()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            self.ram_used = psutil.Process().memory_info().rss
            if self.model_manager:
                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
            # Return the context so that `with ... as ctx:` yields a usable object.
            return self

        def __exit__(self, *args):
            """Called on exit from the context; reports RAM, time and peak VRAM to the collector."""
            ram_used = psutil.Process().memory_info().rss
            self.collector.update_mem_stats(
                ram_used=ram_used / GIG,
                ram_changed=(ram_used - self.ram_used) / GIG,
            )
            self.collector.update_invocation_stats(
                graph_id=self.graph_id,
                invocation_type=self.invocation.type,
                time_used=time.time() - self.start_time,
                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
            )

    def collect_stats(
        self,
        invocation: BaseInvocation,
        graph_execution_state_id: str,
        model_manager: ModelManagerService,
    ) -> StatsContext:
        """
        Return a context object that will capture the statistics.

        :param invocation: BaseInvocation object from the current graph.
        :param graph_execution_state_id: ID of the GraphExecutionState for the current session.
        :param model_manager: Model manager whose cache statistics are collected during the run.
        """
        if graph_execution_state_id not in self._stats:  # first time we're seeing this
            self._stats[graph_execution_state_id] = NodeLog()
            self._cache_stats[graph_execution_state_id] = CacheStats()
        return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)

    def reset_all_stats(self):
        """Zero all statistics, including the per-graph cache statistics."""
        self._stats = {}
        self._cache_stats = {}

    def reset_stats(self, graph_execution_id: str):
        """
        Zero the statistics for the indicated graph.

        Logs a warning when no statistics are recorded for the graph.

        :param graph_execution_id: ID of the graph whose statistics are discarded.
        """
        try:
            self._stats.pop(graph_execution_id)
        except KeyError:
            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
        # Drop the matching cache statistics too so the two dicts stay in step.
        self._cache_stats.pop(graph_execution_id, None)

    def update_mem_stats(
        self,
        ram_used: float,
        ram_changed: float,
    ):
        """
        Update the collector with RAM memory usage info.

        :param ram_used: How much RAM is currently in use.
        :param ram_changed: How much RAM changed since last generation.
        """
        self.ram_used = ram_used
        self.ram_changed = ram_changed

    def update_invocation_stats(
        self,
        graph_id: str,
        invocation_type: str,
        time_used: float,
        vram_used: float,
    ):
        """
        Add timing information on execution of a node. Usually
        used internally.

        :param graph_id: ID of the graph that is currently executing
        :param invocation_type: String literal type of the node
        :param time_used: Time used by the node's execution (sec)
        :param vram_used: Maximum VRAM used during execution (GB)
        :raises KeyError: if no statistics have been started for graph_id
        """
        nodes = self._stats[graph_id].nodes
        if invocation_type not in nodes:
            nodes[invocation_type] = NodeStats()
        stats = nodes[invocation_type]
        stats.calls += 1
        stats.time_used += time_used
        stats.max_vram = max(stats.max_vram, vram_used)

    def log_stats(self):
        """
        Send the statistics to the system logger at the info level.
        Stats will only be printed when the execution of the graph
        is complete. Statistics of graphs that were printed, or whose
        stored state fails validation, are discarded afterwards.
        """
        # Graphs to forget once the loop is over. Entries must not be removed
        # from self._stats while it is being iterated, as that raises
        # "RuntimeError: dictionary changed size during iteration".
        finished = set()
        for graph_id, node_log in self._stats.items():
            try:
                current_graph_state = self.graph_execution_manager.get(graph_id)
            except ValidationError:
                # The stored graph could not be parsed; its stats can never be reported.
                logger.warning(f"Graph {graph_id} failed validation; discarding its statistics")
                finished.add(graph_id)
                continue

            if not current_graph_state.is_complete():
                continue

            total_time = 0.0
            logger.info(f"Graph stats: {graph_id}")
            logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
            for node_type, stats in node_log.nodes.items():
                logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
                total_time += stats.time_used

            cache_stats = self._cache_stats[graph_id]
            hwm = cache_stats.high_watermark / GIG
            tot = cache_stats.cache_size / GIG
            loaded = sum(cache_stats.loaded_model_sizes.values()) / GIG

            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
            logger.info(f"RAM used by InvokeAI process: {self.ram_used:4.2f}G ({self.ram_changed:+5.3f}G)")
            logger.info(f"RAM used to load models: {loaded:4.2f}G")
            if torch.cuda.is_available():
                logger.info(f"VRAM in use: {torch.cuda.memory_allocated() / GIG:4.3f}G")
            logger.info("RAM cache statistics:")
            logger.info(f" Model cache hits: {cache_stats.hits}")
            logger.info(f" Model cache misses: {cache_stats.misses}")
            logger.info(f" Models cached: {cache_stats.in_cache}")
            logger.info(f" Models cleared from cache: {cache_stats.cleared}")
            logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")

            finished.add(graph_id)

        for graph_id in finished:
            del self._stats[graph_id]
            self._cache_stats.pop(graph_id, None)
|