Add malloc info to MemorySnapshot.

This commit is contained in:
Ryan Dick 2023-10-03 11:37:47 -04:00
parent 2a3c0ab5d2
commit 75b65597af

View File

@ -4,13 +4,15 @@ from typing import Optional
import psutil import psutil
import torch import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB GB = 2**30 # 1 GB
class MemorySnapshot: class MemorySnapshot:
"""A snapshot of RAM and VRAM usage. All values are in bytes.""" """A snapshot of RAM and VRAM usage. All values are in bytes."""
def __init__(self, process_ram: int, vram: Optional[int]): def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
"""Initialize a MemorySnapshot. """Initialize a MemorySnapshot.
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
@ -18,9 +20,11 @@ class MemorySnapshot:
Args: Args:
process_ram (int): CPU RAM used by the current process. process_ram (int): CPU RAM used by the current process.
vram (Optional[int]): VRAM used by torch. vram (Optional[int]): VRAM used by torch.
malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
""" """
self.process_ram = process_ram self.process_ram = process_ram
self.vram = vram self.vram = vram
self.malloc_info = malloc_info
@classmethod @classmethod
def capture(cls, run_garbage_collector: bool = True): def capture(cls, run_garbage_collector: bool = True):
@ -49,7 +53,15 @@ class MemorySnapshot:
# time to test it properly. # time to test it properly.
vram = None vram = None
return cls(process_ram, vram) malloc_info = None
try:
malloc_info = LibcUtil().mallinfo2()
except Exception:
# TODO(ryand): Catch a more specific exception.
# This is expected in environments that do not have the 'libc.so.6' shared library.
pass
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str: def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str: