fix(stats): fix InvocationStatsService types

- move docstrings to ABC
- `start_time: int` -> `start_time: float`
- remove class attribute assignments in `StatsContext`
- add `update_mem_stats()` to ABC
- add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring
This commit is contained in:
psychedelicious 2023-08-18 20:16:45 +10:00
parent d1d2d5a47d
commit 1b70bd1380

View File

@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
GIG = 1073741824 GIG = 1073741824
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsServiceBase(ABC): class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics" "Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod @abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
""" """
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str, invocation_type: str,
time_used: float, time_used: float,
vram_used: float, vram_used: float,
ram_used: float,
ram_changed: float,
): ):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec) :param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB) :param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
pass pass
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
""" """
pass pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
@dataclass :param ram_used: How much RAM is currently in use.
class NodeStats: :param ram_changed: How much RAM changed since last generation.
"""Class for tracking execution stats of an invocation node""" """
pass
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase): class InvocationStatsService(InvocationStatsServiceBase):
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
class StatsContext: class StatsContext:
"""Context manager for collecting statistics.""" """Context manager for collecting statistics."""
invocation: BaseInvocation = None invocation: BaseInvocation
collector: "InvocationStatsServiceBase" = None collector: "InvocationStatsServiceBase"
graph_id: str = None graph_id: str
start_time: int = 0 start_time: float
ram_used: int = 0 ram_used: int
model_manager: ModelManagerService = None model_manager: ModelManagerService
def __init__( def __init__(
self, self,
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.invocation = invocation self.invocation = invocation
self.collector = collector self.collector = collector
self.graph_id = graph_id self.graph_id = graph_id
self.start_time = 0 self.start_time = 0.0
self.ram_used = 0 self.ram_used = 0
self.model_manager = model_manager self.model_manager = model_manager
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
) )
self.collector.update_invocation_stats( self.collector.update_invocation_stats(
graph_id=self.graph_id, graph_id=self.graph_id,
invocation_type=self.invocation.type, invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time, time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
) )
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_execution_state_id: str, graph_execution_state_id: str,
model_manager: ModelManagerService, model_manager: ModelManagerService,
) -> StatsContext: ) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog() self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats() self._cache_stats[graph_execution_state_id] = CacheStats()
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats = {} self._stats = {}
def reset_stats(self, graph_execution_id: str): def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try: try:
self._stats.pop(graph_execution_id) self._stats.pop(graph_execution_id)
except KeyError: except KeyError:
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
ram_used: float, ram_used: float,
ram_changed: float, ram_changed: float,
): ):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used self.ram_used = ram_used
self.ram_changed = ram_changed self.ram_changed = ram_changed
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
time_used: float, time_used: float,
vram_used: float, vram_used: float,
): ):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type): if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats() self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type] stats = self._stats[graph_id].nodes[invocation_type]
@ -262,11 +257,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
stats.max_vram = max(stats.max_vram, vram_used) stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self): def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set() completed = set()
errored = set() errored = set()
for graph_id, node_log in self._stats.items(): for graph_id, node_log in self._stats.items():