mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(stats): refactor InvocationStatsService to output stats as dataclasses
This allows the stats to be written to disk as JSON and analyzed. - Add dataclasses to hold stats. - Move stats pretty-print logic to `__str__` of the new `InvocationStatsSummary` class. - Add `get_stats` and `dump_stats` methods to `InvocationStatsServiceBase`. - `InvocationStatsService` now throws if stats are requested for a session it doesn't know about. This avoids needing to do a lot of messy null checks. - Update `DefaultInvocationProcessor` to use the new stats methods and suppresses the new errors.
This commit is contained in:
parent
25291a2e01
commit
b24e8dd829
@ -1,11 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from contextlib import suppress
|
||||||
from threading import BoundedSemaphore, Event, Thread
|
from threading import BoundedSemaphore, Event, Thread
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||||
|
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||||
|
GESStatsNotFoundError,
|
||||||
|
)
|
||||||
from invokeai.app.util.profiler import Profiler
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
@ -152,7 +156,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
with suppress(GESStatsNotFoundError):
|
||||||
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -177,7 +182,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
with suppress(GESStatsNotFoundError):
|
||||||
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
@ -209,15 +215,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
|
with suppress(GESStatsNotFoundError):
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(
|
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
|
||||||
queue_batch_id=queue_item.session_queue_batch_id,
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_batch_id=queue_item.session_queue_batch_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
queue_id=queue_item.session_queue_id,
|
||||||
)
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
if profiler:
|
)
|
||||||
profiler.stop()
|
if profiler:
|
||||||
|
profile_path = profiler.stop()
|
||||||
|
stats_path = profile_path.with_suffix(".json")
|
||||||
|
self.__invoker.services.performance_statistics.dump_stats(
|
||||||
|
graph_execution_state_id=graph_execution_state.id, output_path=stats_path
|
||||||
|
)
|
||||||
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||||
|
@ -30,8 +30,10 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import AbstractContextManager
|
from contextlib import AbstractContextManager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsServiceBase(ABC):
|
class InvocationStatsServiceBase(ABC):
|
||||||
@ -61,8 +63,9 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset_stats(self, graph_execution_state_id: str):
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
"""
|
"""
|
||||||
Reset all statistics for the indicated graph
|
Reset all statistics for the indicated graph.
|
||||||
:param graph_execution_state_id
|
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||||
|
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -70,5 +73,26 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
def log_stats(self, graph_execution_state_id: str):
|
def log_stats(self, graph_execution_state_id: str):
|
||||||
"""
|
"""
|
||||||
Write out the accumulated statistics to the log or somewhere else.
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
|
:param graph_execution_state_id: The id of the session whose stats to log.
|
||||||
|
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||||
|
"""
|
||||||
|
Gets the accumulated statistics for the indicated graph.
|
||||||
|
:param graph_execution_state_id: The id of the session whose stats to get.
|
||||||
|
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
Write out the accumulated statistics to the indicated path as JSON.
|
||||||
|
:param graph_execution_state_id: The id of the session whose stats to dump.
|
||||||
|
:param output_path: The file to write the stats to.
|
||||||
|
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
@ -1,5 +1,91 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class GESStatsNotFoundError(Exception):
|
||||||
|
"""Raised when execution stats are not found for a given Graph Execution State."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeExecutionStatsSummary:
|
||||||
|
"""The stats for a specific type of node."""
|
||||||
|
|
||||||
|
node_type: str
|
||||||
|
num_calls: int
|
||||||
|
time_used_seconds: float
|
||||||
|
peak_vram_gb: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelCacheStatsSummary:
|
||||||
|
"""The stats for the model cache."""
|
||||||
|
|
||||||
|
high_water_mark_gb: float
|
||||||
|
cache_size_gb: float
|
||||||
|
total_usage_gb: float
|
||||||
|
cache_hits: int
|
||||||
|
cache_misses: int
|
||||||
|
models_cached: int
|
||||||
|
models_cleared: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GraphExecutionStatsSummary:
|
||||||
|
"""The stats for the graph execution state."""
|
||||||
|
|
||||||
|
graph_execution_state_id: str
|
||||||
|
execution_time_seconds: float
|
||||||
|
# `wall_time_seconds`, `ram_usage_gb` and `ram_change_gb` are derived from the node execution stats.
|
||||||
|
# In some situations, there are no node stats, so these values are optional.
|
||||||
|
wall_time_seconds: Optional[float]
|
||||||
|
ram_usage_gb: Optional[float]
|
||||||
|
ram_change_gb: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvocationStatsSummary:
|
||||||
|
"""
|
||||||
|
The accumulated stats for a graph execution.
|
||||||
|
Its `__str__` method returns a human-readable stats summary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vram_usage_gb: Optional[float]
|
||||||
|
graph_stats: GraphExecutionStatsSummary
|
||||||
|
model_cache_stats: ModelCacheStatsSummary
|
||||||
|
node_stats: list[NodeExecutionStatsSummary]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
_str = ""
|
||||||
|
_str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
|
||||||
|
_str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
|
||||||
|
|
||||||
|
for summary in self.node_stats:
|
||||||
|
_str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
|
||||||
|
|
||||||
|
_str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"
|
||||||
|
|
||||||
|
if self.graph_stats.wall_time_seconds is not None:
|
||||||
|
_str += f"TOTAL GRAPH WALL TIME: {self.graph_stats.wall_time_seconds:7.3f}s\n"
|
||||||
|
|
||||||
|
if self.graph_stats.ram_usage_gb is not None and self.graph_stats.ram_change_gb is not None:
|
||||||
|
_str += f"RAM used by InvokeAI process: {self.graph_stats.ram_usage_gb:4.2f}G ({self.graph_stats.ram_change_gb:+5.3f}G)\n"
|
||||||
|
|
||||||
|
_str += f"RAM used to load models: {self.model_cache_stats.total_usage_gb:4.2f}G\n"
|
||||||
|
if self.vram_usage_gb:
|
||||||
|
_str += f"VRAM in use: {self.vram_usage_gb:4.3f}G\n"
|
||||||
|
_str += "RAM cache statistics:\n"
|
||||||
|
_str += f" Model cache hits: {self.model_cache_stats.cache_hits}\n"
|
||||||
|
_str += f" Model cache misses: {self.model_cache_stats.cache_misses}\n"
|
||||||
|
_str += f" Models cached: {self.model_cache_stats.models_cached}\n"
|
||||||
|
_str += f" Models cleared from cache: {self.model_cache_stats.models_cleared}\n"
|
||||||
|
_str += f" Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}/{self.model_cache_stats.cache_size_gb:4.2f}G\n"
|
||||||
|
|
||||||
|
return _str
|
||||||
|
|
||||||
|
def as_dict(self) -> dict[str, Any]:
|
||||||
|
"""Returns the stats as a dictionary."""
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -55,12 +141,33 @@ class GraphExecutionStats:
|
|||||||
|
|
||||||
return last_node
|
return last_node
|
||||||
|
|
||||||
def get_pretty_log(self, graph_execution_state_id: str) -> str:
|
def get_graph_stats_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
|
||||||
log = f"Graph stats: {graph_execution_state_id}\n"
|
"""Get a summary of the graph stats."""
|
||||||
log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
|
first_node = self.get_first_node_stats()
|
||||||
|
last_node = self.get_last_node_stats()
|
||||||
|
|
||||||
# Log stats aggregated by node type.
|
wall_time_seconds: Optional[float] = None
|
||||||
|
ram_usage_gb: Optional[float] = None
|
||||||
|
ram_change_gb: Optional[float] = None
|
||||||
|
|
||||||
|
if last_node and first_node:
|
||||||
|
wall_time_seconds = last_node.end_time - first_node.start_time
|
||||||
|
ram_usage_gb = last_node.end_ram_gb
|
||||||
|
ram_change_gb = last_node.end_ram_gb - first_node.start_ram_gb
|
||||||
|
|
||||||
|
return GraphExecutionStatsSummary(
|
||||||
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
|
execution_time_seconds=self.get_total_run_time(),
|
||||||
|
wall_time_seconds=wall_time_seconds,
|
||||||
|
ram_usage_gb=ram_usage_gb,
|
||||||
|
ram_change_gb=ram_change_gb,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
|
||||||
|
"""Get a summary of the node stats."""
|
||||||
|
summaries: list[NodeExecutionStatsSummary] = []
|
||||||
node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
|
node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
|
||||||
|
|
||||||
for node_stats in self._node_stats_list:
|
for node_stats in self._node_stats_list:
|
||||||
node_stats_by_type[node_stats.invocation_type].append(node_stats)
|
node_stats_by_type[node_stats.invocation_type].append(node_stats)
|
||||||
|
|
||||||
@ -68,17 +175,9 @@ class GraphExecutionStats:
|
|||||||
num_calls = len(node_type_stats_list)
|
num_calls = len(node_type_stats_list)
|
||||||
time_used = sum([n.total_time() for n in node_type_stats_list])
|
time_used = sum([n.total_time() for n in node_type_stats_list])
|
||||||
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
|
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
|
||||||
log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
|
summary = NodeExecutionStatsSummary(
|
||||||
|
node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
|
||||||
|
)
|
||||||
|
summaries.append(summary)
|
||||||
|
|
||||||
# Log stats for the entire graph.
|
return summaries
|
||||||
log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
|
|
||||||
|
|
||||||
first_node = self.get_first_node_stats()
|
|
||||||
last_node = self.get_last_node_stats()
|
|
||||||
if first_node is not None and last_node is not None:
|
|
||||||
total_wall_time = last_node.end_time - first_node.start_time
|
|
||||||
ram_change = last_node.end_ram_gb - first_node.start_ram_gb
|
|
||||||
log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
|
|
||||||
log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
|
|
||||||
|
|
||||||
return log
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
@ -10,7 +12,15 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
from .invocation_stats_base import InvocationStatsServiceBase
|
from .invocation_stats_base import InvocationStatsServiceBase
|
||||||
from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
|
from .invocation_stats_common import (
|
||||||
|
GESStatsNotFoundError,
|
||||||
|
GraphExecutionStats,
|
||||||
|
GraphExecutionStatsSummary,
|
||||||
|
InvocationStatsSummary,
|
||||||
|
ModelCacheStatsSummary,
|
||||||
|
NodeExecutionStats,
|
||||||
|
NodeExecutionStatsSummary,
|
||||||
|
)
|
||||||
|
|
||||||
# Size of 1GB in bytes.
|
# Size of 1GB in bytes.
|
||||||
GB = 2**30
|
GB = 2**30
|
||||||
@ -95,31 +105,66 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
del self._stats[graph_execution_state_id]
|
del self._stats[graph_execution_state_id]
|
||||||
del self._cache_stats[graph_execution_state_id]
|
del self._cache_stats[graph_execution_state_id]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
|
msg = f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||||
|
logger.warning(msg)
|
||||||
|
raise GESStatsNotFoundError(msg)
|
||||||
|
|
||||||
def log_stats(self, graph_execution_state_id: str):
|
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||||
|
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||||
|
node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
|
||||||
|
model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
|
||||||
|
vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None
|
||||||
|
|
||||||
|
return InvocationStatsSummary(
|
||||||
|
graph_stats=graph_stats_summary,
|
||||||
|
model_cache_stats=model_cache_stats_summary,
|
||||||
|
node_stats=node_stats_summaries,
|
||||||
|
vram_usage_gb=vram_usage_gb,
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_stats(self, graph_execution_state_id: str) -> None:
|
||||||
|
stats = self.get_stats(graph_execution_state_id)
|
||||||
|
logger.info(str(stats))
|
||||||
|
|
||||||
|
def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
|
||||||
|
stats = self.get_stats(graph_execution_state_id)
|
||||||
|
with open(output_path, "w") as f:
|
||||||
|
f.write(json.dumps(stats.as_dict(), indent=2))
|
||||||
|
|
||||||
|
def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
|
||||||
try:
|
try:
|
||||||
graph_stats = self._stats[graph_execution_state_id]
|
|
||||||
cache_stats = self._cache_stats[graph_execution_state_id]
|
cache_stats = self._cache_stats[graph_execution_state_id]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
|
msg = f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||||
return
|
logger.warning(msg)
|
||||||
|
raise GESStatsNotFoundError(msg)
|
||||||
|
|
||||||
log = graph_stats.get_pretty_log(graph_execution_state_id)
|
return ModelCacheStatsSummary(
|
||||||
|
cache_hits=cache_stats.hits,
|
||||||
|
cache_misses=cache_stats.misses,
|
||||||
|
high_water_mark_gb=cache_stats.high_watermark / GB,
|
||||||
|
cache_size_gb=cache_stats.cache_size / GB,
|
||||||
|
total_usage_gb=sum(list(cache_stats.loaded_model_sizes.values())) / GB,
|
||||||
|
models_cached=cache_stats.in_cache,
|
||||||
|
models_cleared=cache_stats.cleared,
|
||||||
|
)
|
||||||
|
|
||||||
hwm = cache_stats.high_watermark / GB
|
def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
|
||||||
tot = cache_stats.cache_size / GB
|
try:
|
||||||
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
|
graph_stats = self._stats[graph_execution_state_id]
|
||||||
log += f"RAM used to load models: {loaded:4.2f}G\n"
|
except KeyError as e:
|
||||||
if torch.cuda.is_available():
|
msg = f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||||
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
|
logger.warning(msg)
|
||||||
log += "RAM cache statistics:\n"
|
raise GESStatsNotFoundError(msg)
|
||||||
log += f" Model cache hits: {cache_stats.hits}\n"
|
|
||||||
log += f" Model cache misses: {cache_stats.misses}\n"
|
|
||||||
log += f" Models cached: {cache_stats.in_cache}\n"
|
|
||||||
log += f" Models cleared from cache: {cache_stats.cleared}\n"
|
|
||||||
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
|
|
||||||
logger.info(log)
|
|
||||||
|
|
||||||
del self._stats[graph_execution_state_id]
|
return graph_stats.get_graph_stats_summary(graph_execution_state_id)
|
||||||
del self._cache_stats[graph_execution_state_id]
|
|
||||||
|
def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
|
||||||
|
try:
|
||||||
|
graph_stats = self._stats[graph_execution_state_id]
|
||||||
|
except KeyError as e:
|
||||||
|
msg = f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||||
|
logger.warning(msg)
|
||||||
|
raise GESStatsNotFoundError(msg)
|
||||||
|
|
||||||
|
return graph_stats.get_node_stats_summaries()
|
||||||
|
Loading…
Reference in New Issue
Block a user