From 763dcacfd3f745563cd6ab88562bfe337443886e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 3 Oct 2023 14:20:46 -0400 Subject: [PATCH] Add unit test for get_pretty_snapshot_diff(...). --- .../model_management/memory_snapshot.py | 9 +++-- .../model_management/test_memory_snapshot.py | 35 +++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 tests/backend/model_management/test_memory_snapshot.py diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_management/memory_snapshot.py index af47a0d0bf..4f43affcf7 100644 --- a/invokeai/backend/model_management/memory_snapshot.py +++ b/invokeai/backend/model_management/memory_snapshot.py @@ -65,10 +65,8 @@ class MemorySnapshot: def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str: """Get a pretty string describing the difference between two `MemorySnapshot`s.""" - def get_msg_line(prefix: str, val1: Optional[int], val2: Optional[int]): - diff = None - if val1 is not None and val2 is not None: - diff = val2 - val1 + def get_msg_line(prefix: str, val1: int, val2: int): + diff = val2 - val1 return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" msg = "" @@ -90,6 +88,7 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnaps libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) - msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) + if snapshot_1.vram is not None and snapshot_2.vram is not None: + msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) return msg diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_management/test_memory_snapshot.py new file mode 100644 index 0000000000..80aed7b7ba --- /dev/null +++ b/tests/backend/model_management/test_memory_snapshot.py @@ -0,0 +1,35 @@ +import pytest + +from invokeai.backend.model_management.libc_util import Struct_mallinfo2 +from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff + + +def test_memory_snapshot_capture(): + """Smoke test of MemorySnapshot.capture().""" + snapshot = MemorySnapshot.capture() + + # We just check process_ram, because it is the only field that should be supported on all platforms. + assert snapshot.process_ram > 0 + + +snapshots = [ + MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()), + MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None), + MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()), + MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None), +] + + +@pytest.mark.parametrize("snapshot_1", snapshots) +@pytest.mark.parametrize("snapshot_2", snapshots) +def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2): + """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields.""" + msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2) + + expected_lines = 1 + if snapshot_1.vram is not None and snapshot_2.vram is not None: + expected_lines += 1 + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + expected_lines += 5 + + assert len(msg.splitlines()) == expected_lines