From aa45d21fd2fd55fa3bfd4b82269af9b2b2327923 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 11 Jan 2024 13:03:03 -0500
Subject: [PATCH] Reduce the number of graph_execution_manager.get(...) calls
 from the InvocationStatsService.

---
 .../invocation_processor_default.py           |  2 +-
 .../invocation_stats/invocation_stats_base.py |  2 +-
 .../invocation_stats_default.py               | 84 +++++++++++--------
 3 files changed, 51 insertions(+), 37 deletions(-)

diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py
index 50657b0984..09608dca2b 100644
--- a/invokeai/app/services/invocation_processor/invocation_processor_default.py
+++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py
@@ -132,7 +132,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             source_node_id=source_node_id,
                             result=outputs.model_dump(),
                         )
-                        self.__invoker.services.performance_statistics.log_stats()
 
                     except KeyboardInterrupt:
                         pass
@@ -195,6 +194,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             error=traceback.format_exc(),
                         )
                     elif is_complete:
+                        self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
                         self.__invoker.services.events.emit_graph_execution_complete(
                             queue_batch_id=queue_item.session_queue_batch_id,
                             queue_item_id=queue_item.session_queue_item_id,
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py
index 365bd9cb6c..6e5b6a9f69 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_base.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py
@@ -67,7 +67,7 @@ class InvocationStatsServiceBase(ABC):
         pass
 
     @abstractmethod
-    def log_stats(self):
+    def log_stats(self, graph_execution_state_id: str):
         """
         Write out the accumulated statistics to the log or somewhere else.
         """
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py
index 7e89e67e5e..ddc05ffe1c 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_default.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py
@@ -36,6 +36,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
         self._stats[graph_execution_state_id] = GraphExecutionStats()
         self._cache_stats[graph_execution_state_id] = CacheStats()
 
+        # Prune stale stats. There should be none since we're starting a new graph, but just in case.
+        self._prune_stale_stats()
+
         # Record state before the invocation.
         start_time = time.time()
         start_ram = psutil.Process().memory_info().rss
@@ -59,49 +62,60 @@ class InvocationStatsService(InvocationStatsServiceBase):
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
-    def reset_stats(self, graph_execution_id: str):
-        try:
-            self._stats.pop(graph_execution_id)
-        except KeyError:
-            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
+    def _prune_stale_stats(self):
+        """Check all graphs being tracked and prune any that have completed/errored.
 
-    def log_stats(self):
-        completed = set()
-        errored = set()
-        for graph_id, _node_log in self._stats.items():
+        This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
+        for now we call this function periodically to prevent them from accumulating.
+        """
+        to_prune = []
+        for graph_execution_state_id in self._stats:
             try:
-                current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
+                graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
             except Exception:
-                errored.add(graph_id)
+                # TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
+                logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
                 continue
 
-            if not current_graph_state.is_complete():
+            if not graph_execution_state.is_complete():
+                # The graph is still running, don't prune it.
                 continue
 
-            graph_stats = self._stats[graph_id]
-            log = graph_stats.get_pretty_log(graph_id)
+            to_prune.append(graph_execution_state_id)
 
-            cache_stats = self._cache_stats[graph_id]
-            hwm = cache_stats.high_watermark / GB
-            tot = cache_stats.cache_size / GB
-            loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
-            log += f"RAM used to load models: {loaded:4.2f}G\n"
-            if torch.cuda.is_available():
-                log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
-            log += "RAM cache statistics:\n"
-            log += f"   Model cache hits: {cache_stats.hits}\n"
-            log += f"   Model cache misses: {cache_stats.misses}\n"
-            log += f"   Models cached: {cache_stats.in_cache}\n"
-            log += f"   Models cleared from cache: {cache_stats.cleared}\n"
-            log += f"   Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
-            logger.info(log)
+        for graph_execution_state_id in to_prune:
+            del self._stats[graph_execution_state_id]
+            del self._cache_stats[graph_execution_state_id]
 
-            completed.add(graph_id)
+        if len(to_prune) > 0:
+            logger.info(f"Pruned stale graph stats for {to_prune}.")
 
-        for graph_id in completed:
-            del self._stats[graph_id]
-            del self._cache_stats[graph_id]
+    def reset_stats(self, graph_execution_state_id: str):
+        try:
+            del self._stats[graph_execution_state_id]
+            del self._cache_stats[graph_execution_state_id]
+        except KeyError as e:
+            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
 
-        for graph_id in errored:
-            del self._stats[graph_id]
-            del self._cache_stats[graph_id]
+    def log_stats(self, graph_execution_state_id: str):
+        graph_stats = self._stats[graph_execution_state_id]
+        cache_stats = self._cache_stats[graph_execution_state_id]
+
+        log = graph_stats.get_pretty_log(graph_execution_state_id)
+
+        hwm = cache_stats.high_watermark / GB
+        tot = cache_stats.cache_size / GB
+        loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
+        log += f"RAM used to load models: {loaded:4.2f}G\n"
+        if torch.cuda.is_available():
+            log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
+        log += "RAM cache statistics:\n"
+        log += f"   Model cache hits: {cache_stats.hits}\n"
+        log += f"   Model cache misses: {cache_stats.misses}\n"
+        log += f"   Models cached: {cache_stats.in_cache}\n"
+        log += f"   Models cleared from cache: {cache_stats.cleared}\n"
+        log += f"   Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
+        logger.info(log)
+
+        del self._stats[graph_execution_state_id]
+        del self._cache_stats[graph_execution_state_id]
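For illustration only, below is a minimal, self-contained sketch of the bookkeeping pattern this patch moves to: per-graph stats keyed by graph_execution_state_id, flushed exactly once by log_stats(graph_execution_state_id) when the caller already knows the graph is complete, so the logging path no longer needs graph_execution_manager.get(...) lookups. All names in the sketch (StatsTracker, GraphStats, NodeStats, record_node) are hypothetical stand-ins, not InvokeAI APIs.

import logging
from dataclasses import dataclass, field
from typing import Dict, List

logger = logging.getLogger(__name__)


@dataclass
class NodeStats:
    # Per-node timing record (hypothetical stand-in for the patch's node execution stats).
    node_type: str
    duration_s: float


@dataclass
class GraphStats:
    # Accumulated stats for a single graph execution.
    node_stats: List[NodeStats] = field(default_factory=list)

    def get_pretty_log(self, graph_id: str) -> str:
        lines = [f"Graph stats: {graph_id}"]
        for ns in self.node_stats:
            lines.append(f"  {ns.node_type:<24}{ns.duration_s:7.3f}s")
        lines.append(f"  TOTAL: {sum(ns.duration_s for ns in self.node_stats):.3f}s")
        return "\n".join(lines)


class StatsTracker:
    """Hypothetical tracker mirroring the patch's pattern: accumulate per graph id and
    flush exactly once when the caller reports that graph complete."""

    def __init__(self) -> None:
        self._stats: Dict[str, GraphStats] = {}

    def record_node(self, graph_id: str, node_type: str, duration_s: float) -> None:
        # Called once per node execution; creates the per-graph entry on first use.
        self._stats.setdefault(graph_id, GraphStats()).node_stats.append(
            NodeStats(node_type=node_type, duration_s=duration_s)
        )

    def log_stats(self, graph_id: str) -> None:
        # The caller passes the completed graph's id, so no graph-state lookup is needed
        # here; the entry is logged and dropped in one step.
        graph_stats = self._stats.pop(graph_id)
        logger.info(graph_stats.get_pretty_log(graph_id))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tracker = StatsTracker()
    tracker.record_node("graph-1", "main_model_loader", 0.12)
    tracker.record_node("graph-1", "denoise_latents", 4.56)
    tracker.log_stats("graph-1")  # one flush per completed graph, as in the patch

The design point mirrored here is that the caller owns the notion of "graph complete" and passes the id down, so the stats service never has to poll graph state on the logging path; in the patch, polling survives only in _prune_stale_stats() as a defensive cleanup.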