write RAM usage and change after each generation

This commit is contained in:
Lincoln Stein 2023-08-15 18:21:31 -04:00
parent d6c9bf5b38
commit a4b029d03c

View File

@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
writes to the system log is stored in InvocationServices.performance_statistics. writes to the system log is stored in InvocationServices.performance_statistics.
""" """
import psutil
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
@ -83,13 +84,14 @@ class InvocationStatsServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def update_invocation_stats( def update_invocation_stats(self,
self, graph_id: str,
graph_id: str, invocation_type: str,
invocation_type: str, time_used: float,
time_used: float, vram_used: float,
vram_used: float, ram_used: float,
): ram_changed: float,
):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
used internally. used internally.
@ -97,6 +99,8 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Time used by node's execution (sec) :param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB) :param vram_used: Maximum VRAM used during execution (GB)
:param ram_used: Current RAM in use (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
pass pass
@ -140,18 +144,23 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.collector = collector self.collector = collector
self.graph_id = graph_id self.graph_id = graph_id
self.start_time = 0 self.start_time = 0
self.ram_info = None
def __enter__(self): def __enter__(self):
self.start_time = time.time() self.start_time = time.time()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
self.ram_info = psutil.virtual_memory()
def __exit__(self, *args): def __exit__(self, *args):
self.collector.update_invocation_stats( self.collector.update_invocation_stats(
self.graph_id, graph_id = self.graph_id,
self.invocation.type, invocation_type = self.invocation.type,
time.time() - self.start_time, time_used = time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0, vram_used = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
ram_used = psutil.virtual_memory().used / 1e9,
ram_changed = (psutil.virtual_memory().used - self.ram_info.used) / 1e9,
) )
def collect_stats( def collect_stats(
@ -179,13 +188,23 @@ class InvocationStatsService(InvocationStatsServiceBase):
except KeyError: except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}") logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float): def update_invocation_stats(self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
used internally. used internally.
:param graph_id: ID of the graph that is currently executing :param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's execution :param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
:param ram_used: Current RAM in use (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
if not self._stats[graph_id].nodes.get(invocation_type): if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats() self._stats[graph_id].nodes[invocation_type] = NodeStats()
@ -193,6 +212,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
stats.calls += 1 stats.calls += 1
stats.time_used += time_used stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used) stats.max_vram = max(stats.max_vram, vram_used)
stats.ram_used = ram_used
stats.ram_changed = ram_changed
def log_stats(self): def log_stats(self):
""" """
@ -214,8 +235,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
total_time += stats.time_used total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("Current RAM used: " + "%4.2fG" % stats.ram_used + f" (delta={stats.ram_changed:4.2f}G)")
if torch.cuda.is_available(): if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) logger.info("Current VRAM used: " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id) completed.add(graph_id)