reset stats on exception

Lincoln Stein 2023-08-01 19:39:42 -04:00
parent fd7b842419
commit 8a4e5f73aa
2 changed files with 21 additions and 21 deletions

File 1 of 2:

@@ -43,14 +43,15 @@ from ..invocations.baseinvocation import BaseInvocation
 import invokeai.backend.util.logging as logger
 
-class InvocationStats():
+class InvocationStats:
     """Accumulate performance information about a running graph. Collects time spent in each node,
     as well as the maximum and current VRAM utilisation for CUDA systems"""
 
     def __init__(self):
         self._stats: Dict[str, int] = {}
 
-    class StatsContext():
+    class StatsContext:
         def __init__(self, invocation: BaseInvocation, collector):
             self.invocation = invocation
             self.collector = collector
@@ -61,17 +62,18 @@ class InvocationStats():
         def __exit__(self, *args):
             self.collector.log_time(self.invocation.type, time.time() - self.start_time)
 
-    def collect_stats(self,
-                      invocation: BaseInvocation,
-                      graph_execution_state: GraphExecutionState,
-                      ) -> StatsContext:
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state: GraphExecutionState,
+    ) -> StatsContext:
         """
         Return a context object that will capture the statistics.
         :param invocation: BaseInvocation object from the current graph.
         :param graph_execution_state: GraphExecutionState object from the current session.
         """
-        if len(graph_execution_state.executed)==0: # new graph is starting
+        if len(graph_execution_state.executed) == 0:  # new graph is starting
             self.reset_stats()
         self._current_graph_state = graph_execution_state
         sc = self.StatsContext(invocation, self)
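
Taken together, the hunks above define a small timing API: a nested StatsContext context manager clocks each node and reports back into the collector, and collect_stats() resets the accumulator whenever a fresh graph starts. Below is a condensed, runnable sketch of that pattern, not the actual InvokeAI class: InvocationStatsSketch, executed_count, and the demo loop are illustrative stand-ins, and the CUDA bookkeeping is omitted.

import time
from typing import Dict, List


class InvocationStatsSketch:
    """Illustrative stand-in for InvocationStats: accumulates per-node-type
    call counts and wall-clock time via a nested context manager."""

    def __init__(self):
        # invocation type -> [call count, total seconds]; the count is an int
        # stored in the same list as the float total, as in the diff above
        self._stats: Dict[str, List[float]] = {}

    class StatsContext:
        def __init__(self, invocation_type: str, collector: "InvocationStatsSketch"):
            self.invocation_type = invocation_type
            self.collector = collector

        def __enter__(self):
            self.start_time = time.time()
            return self

        def __exit__(self, *args):
            # Returning None lets any exception propagate to the processor.
            self.collector.log_time(self.invocation_type, time.time() - self.start_time)

    def collect_stats(self, invocation_type: str, executed_count: int) -> "InvocationStatsSketch.StatsContext":
        if executed_count == 0:  # a new graph is starting
            self.reset_stats()
        return self.StatsContext(invocation_type, self)

    def reset_stats(self):
        self._stats = {}

    def log_time(self, invocation_type: str, time_used: float):
        if invocation_type not in self._stats:
            self._stats[invocation_type] = [0, 0.0]
        self._stats[invocation_type][0] += 1
        self._stats[invocation_type][1] += time_used


# Usage: wrap each node execution, as the processor does with collect_stats().
stats = InvocationStatsSketch()
for i, node_type in enumerate(["noise", "denoise", "denoise"]):
    with stats.collect_stats(node_type, executed_count=i):
        time.sleep(0.01)  # stand-in for invocation.invoke()
print(stats._stats)  # e.g. {'noise': [1, 0.01...], 'denoise': [2, 0.02...]}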
@@ -83,7 +85,6 @@ class InvocationStats():
             torch.cuda.reset_peak_memory_stats()
         self._stats: Dict[str, List[int, float]] = {}
-
     def log_time(self, invocation_type: str, time_used: float):
         """
         Add timing information on execution of a node. Usually
@@ -95,7 +96,7 @@ class InvocationStats():
             self._stats[invocation_type] = [0, 0.0]
         self._stats[invocation_type][0] += 1
         self._stats[invocation_type][1] += time_used
 
     def log_stats(self):
         """
         Send the statistics to the system logger at the info level.
@@ -103,13 +104,12 @@ class InvocationStats():
         is complete.
         """
         if self._current_graph_state.is_complete():
-            logger.info('Node Calls Seconds')
+            logger.info("Node Calls Seconds")
             for node_type, (calls, time_used) in self._stats.items():
-                logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s')
-            total_time = sum([ticks for _,ticks in self._stats.values()])
-            logger.info(f'TOTAL: {total_time:4.3f}s')
+                logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s")
+            total_time = sum([ticks for _, ticks in self._stats.values()])
+            logger.info(f"TOTAL: {total_time:4.3f}s")
             if torch.cuda.is_available():
-                logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9))
-                logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
+                logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
+                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
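
The quote changes in the final hunk are cosmetic, but the f-string format specifiers do the real work of aligning the report into columns. A quick illustration with invented node names and timings:

stats = {"noise": (1, 0.012), "denoise": (2, 3.456)}
print("Node                 Calls  Seconds")
for node_type, (calls, time_used) in stats.items():
    # :<20 left-aligns the name in a 20-char field; :>5 right-aligns the
    # call count; :4.3f prints the seconds to three decimal places.
    print(f"{node_type:<20} {calls:>5} {time_used:4.3f}s")
total_time = sum(ticks for _, ticks in stats.values())
print(f"TOTAL: {total_time:4.3f}s")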

File 2 of 2:

@@ -116,6 +116,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         pass
 
                     except CanceledException:
+                        statistics.reset_stats()
                         pass
 
                     except Exception as e:
@@ -137,7 +138,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             error_type=e.__class__.__name__,
                             error=error,
                         )
-
+                        statistics.reset_stats()
                         pass
 
                     # Check queue to see if this is canceled, and skip if so
@@ -165,4 +166,3 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
         finally:
             self.__threadLimit.release()
-
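
The processor changes are the point of the commit: whether a graph is canceled (CanceledException) or a node raises, statistics.reset_stats() now runs, so a failed run's partial timings do not leak into the next graph's report. A minimal sketch of that control flow, with CanceledException and run_node here as stand-ins rather than the actual InvokeAI processor:

import logging

logger = logging.getLogger(__name__)


class CanceledException(Exception):
    """Stand-in for the processor's cancellation signal."""


def run_node(statistics, invoke):
    """Illustrative control flow: the collector is reset on both failure
    paths, mirroring the two reset_stats() calls added above."""
    try:
        invoke()
    except CanceledException:
        # Graph canceled mid-run: discard the partial timings.
        statistics.reset_stats()
    except Exception:
        logger.exception("node failed")
        # Node raised: discard partial timings so the next graph's report
        # is not polluted by this one (the real processor also emits an
        # error event here before swallowing the exception).
        statistics.reset_stats()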