From 8a4e5f73aa6e914494885438a560a7d6694b6ce2 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 1 Aug 2023 19:39:42 -0400
Subject: [PATCH] reset stats on exception

---
 invokeai/app/services/invocation_stats.py | 38 +++++++++++------------
 invokeai/app/services/processor.py        |  4 +--
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py
index 8d41b60d49..24a5662647 100644
--- a/invokeai/app/services/invocation_stats.py
+++ b/invokeai/app/services/invocation_stats.py
@@ -43,14 +43,15 @@ from ..invocations.baseinvocation import BaseInvocation
 import invokeai.backend.util.logging as logger
 
-class InvocationStats():
+
+class InvocationStats:
     """Accumulate performance information about a running graph. Collects time spent in each node,
     as well as the maximum and current VRAM utilisation for CUDA systems"""
 
     def __init__(self):
         self._stats: Dict[str, int] = {}
 
-
-    class StatsContext():
+
+    class StatsContext:
         def __init__(self, invocation: BaseInvocation, collector):
             self.invocation = invocation
             self.collector = collector
@@ -61,17 +62,18 @@ class InvocationStats():
 
         def __exit__(self, *args):
             self.collector.log_time(self.invocation.type, time.time() - self.start_time)
-
-    def collect_stats(self,
-                      invocation: BaseInvocation,
-                      graph_execution_state: GraphExecutionState,
-                      ) -> StatsContext:
+
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state: GraphExecutionState,
+    ) -> StatsContext:
         """
         Return a context object that will capture the statistics.
         :param invocation: BaseInvocation object from the current graph.
         :param graph_execution_state: GraphExecutionState object from the current session.
         """
-        if len(graph_execution_state.executed)==0: # new graph is starting
+        if len(graph_execution_state.executed) == 0:  # new graph is starting
             self.reset_stats()
         self._current_graph_state = graph_execution_state
         sc = self.StatsContext(invocation, self)
@@ -83,7 +85,6 @@ class InvocationStats():
             torch.cuda.reset_peak_memory_stats()
         self._stats: Dict[str, List[int, float]] = {}
 
-
     def log_time(self, invocation_type: str, time_used: float):
         """
         Add timing information on execution of a node. Usually
@@ -90,26 +91,25 @@ class InvocationStats():
         invoked by the context manager.
         :param invocation_type: String literal type of node
         :param time_used: Floating point seconds used by node's exection
         """
         if not invocation_type in self._stats:
             self._stats[invocation_type] = [0, 0.0]
         self._stats[invocation_type][0] += 1
         self._stats[invocation_type][1] += time_used
-
+
     def log_stats(self):
         """
         Send the statistics to the system logger at the info level.
         Stats will only be printed if when the execution of the graph
         is complete.
""" if self._current_graph_state.is_complete(): - logger.info('Node Calls Seconds') + logger.info("Node Calls Seconds") for node_type, (calls, time_used) in self._stats.items(): - logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s') - - total_time = sum([ticks for _,ticks in self._stats.values()]) - logger.info(f'TOTAL: {total_time:4.3f}s') + logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") + + total_time = sum([ticks for _, ticks in self._stats.values()]) + logger.info(f"TOTAL: {total_time:4.3f}s") if torch.cuda.is_available(): - logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9)) - logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) - + logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) + logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index a43e2878ac..e9511aa283 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -116,6 +116,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass except CanceledException: + statistics.reset_stats() pass except Exception as e: @@ -137,7 +138,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - + statistics.reset_stats() pass # Check queue to see if this is canceled, and skip if so @@ -165,4 +166,3 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor finally: self.__threadLimit.release() -