From d1d2d5a47d7cdfdb06b83d49ac090978b978f9ec Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 18 Aug 2023 20:06:09 +1000 Subject: [PATCH 1/2] fix(stats): fix fail case when previous graph is invalid When retrieving a graph, it is parsed through pydantic. It is possible that this graph is invalid, and an error is thrown. Handle this by deleting the failed graph from the stats if this occurs. --- invokeai/app/services/invocation_stats.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index e8557c40f7..bfede6a880 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -268,8 +268,14 @@ class InvocationStatsService(InvocationStatsServiceBase): is complete. """ completed = set() + errored = set() for graph_id, node_log in self._stats.items(): - current_graph_state = self.graph_execution_manager.get(graph_id) + try: + current_graph_state = self.graph_execution_manager.get(graph_id) + except Exception: + errored.add(graph_id) + continue + if not current_graph_state.is_complete(): continue @@ -302,3 +308,7 @@ class InvocationStatsService(InvocationStatsServiceBase): for graph_id in completed: del self._stats[graph_id] del self._cache_stats[graph_id] + + for graph_id in errored: + del self._stats[graph_id] + del self._cache_stats[graph_id] From 1b70bd13804737e2980486a4e9069d149b38b9b9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 18 Aug 2023 20:16:45 +1000 Subject: [PATCH 2/2] fix(stats): fix `InvocationStatsService` types - move docstrings to ABC - `start_time: int` -> `start_time: float` - remove class attribute assignments in `StatsContext` - add `update_mem_stats()` to ABC - add class attributes to ABC, because they are referenced in instances of the class. if they should not be on the ABC, then maybe there needs to be some restructuring --- invokeai/app/services/invocation_stats.py | 104 ++++++++++------------ 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index bfede6a880..b42d128b51 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats GIG = 1073741824 +@dataclass +class NodeStats: + """Class for tracking execution stats of an invocation node""" + + calls: int = 0 + time_used: float = 0.0 # seconds + max_vram: float = 0.0 # GB + cache_hits: int = 0 + cache_misses: int = 0 + cache_high_watermark: int = 0 + + +@dataclass +class NodeLog: + """Class for tracking node usage""" + + # {node_type => NodeStats} + nodes: Dict[str, NodeStats] = field(default_factory=dict) + + class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" + graph_execution_manager: ItemStorageABC["GraphExecutionState"] + # {graph_id => NodeLog} + _stats: Dict[str, NodeLog] + _cache_stats: Dict[str, CacheStats] + ram_used: float + ram_changed: float + @abstractmethod def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): """ @@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC): invocation_type: str, time_used: float, vram_used: float, - ram_used: float, - ram_changed: float, ): """ Add timing information on execution of a node. Usually @@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC): :param invocation_type: String literal type of the node :param time_used: Time used by node's exection (sec) :param vram_used: Maximum VRAM used during exection (GB) - :param ram_used: Current RAM available (GB) - :param ram_changed: Change in RAM usage over course of the run (GB) """ pass @@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC): """ pass + @abstractmethod + def update_mem_stats( + self, + ram_used: float, + ram_changed: float, + ): + """ + Update the collector with RAM memory usage info. -@dataclass -class NodeStats: - """Class for tracking execution stats of an invocation node""" - - calls: int = 0 - time_used: float = 0.0 # seconds - max_vram: float = 0.0 # GB - cache_hits: int = 0 - cache_misses: int = 0 - cache_high_watermark: int = 0 - - -@dataclass -class NodeLog: - """Class for tracking node usage""" - - # {node_type => NodeStats} - nodes: Dict[str, NodeStats] = field(default_factory=dict) + :param ram_used: How much RAM is currently in use. + :param ram_changed: How much RAM changed since last generation. + """ + pass class InvocationStatsService(InvocationStatsServiceBase): @@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase): class StatsContext: """Context manager for collecting statistics.""" - invocation: BaseInvocation = None - collector: "InvocationStatsServiceBase" = None - graph_id: str = None - start_time: int = 0 - ram_used: int = 0 - model_manager: ModelManagerService = None + invocation: BaseInvocation + collector: "InvocationStatsServiceBase" + graph_id: str + start_time: float + ram_used: int + model_manager: ModelManagerService def __init__( self, @@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase): self.invocation = invocation self.collector = collector self.graph_id = graph_id - self.start_time = 0 + self.start_time = 0.0 self.ram_used = 0 self.model_manager = model_manager @@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase): ) self.collector.update_invocation_stats( graph_id=self.graph_id, - invocation_type=self.invocation.type, + invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations time_used=time.time() - self.start_time, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, ) @@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase): graph_execution_state_id: str, model_manager: ModelManagerService, ) -> StatsContext: - """ - Return a context object that will capture the statistics. - :param invocation: BaseInvocation object from the current graph. - :param graph_execution_state: GraphExecutionState object from the current session. - """ if not self._stats.get(graph_execution_state_id): # first time we're seeing this self._stats[graph_execution_state_id] = NodeLog() self._cache_stats[graph_execution_state_id] = CacheStats() @@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase): self._stats = {} def reset_stats(self, graph_execution_id: str): - """Zero the statistics for the indicated graph.""" try: self._stats.pop(graph_execution_id) except KeyError: @@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase): ram_used: float, ram_changed: float, ): - """ - Update the collector with RAM memory usage info. - - :param ram_used: How much RAM is currently in use. - :param ram_changed: How much RAM changed since last generation. - """ self.ram_used = ram_used self.ram_changed = ram_changed @@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase): time_used: float, vram_used: float, ): - """ - Add timing information on execution of a node. Usually - used internally. - :param graph_id: ID of the graph that is currently executing - :param invocation_type: String literal type of the node - :param time_used: Time used by node's exection (sec) - :param vram_used: Maximum VRAM used during exection (GB) - :param ram_used: Current RAM available (GB) - :param ram_changed: Change in RAM usage over course of the run (GB) - """ if not self._stats[graph_id].nodes.get(invocation_type): self._stats[graph_id].nodes[invocation_type] = NodeStats() stats = self._stats[graph_id].nodes[invocation_type] @@ -262,11 +257,6 @@ class InvocationStatsService(InvocationStatsServiceBase): stats.max_vram = max(stats.max_vram, vram_used) def log_stats(self): - """ - Send the statistics to the system logger at the info level. - Stats will only be printed when the execution of the graph - is complete. - """ completed = set() errored = set() for graph_id, node_log in self._stats.items():