mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add unit test for get_pretty_snapshot_diff(...).
This commit is contained in:
parent
3599d546e6
commit
763dcacfd3
@ -65,10 +65,8 @@ class MemorySnapshot:
|
|||||||
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
|
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
|
||||||
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
|
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
|
||||||
|
|
||||||
def get_msg_line(prefix: str, val1: Optional[int], val2: Optional[int]):
|
def get_msg_line(prefix: str, val1: int, val2: int):
|
||||||
diff = None
|
diff = val2 - val1
|
||||||
if val1 is not None and val2 is not None:
|
|
||||||
diff = val2 - val1
|
|
||||||
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
|
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
|
||||||
|
|
||||||
msg = ""
|
msg = ""
|
||||||
@ -90,6 +88,7 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnaps
|
|||||||
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
|
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
|
||||||
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
|
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
|
||||||
|
|
||||||
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
||||||
|
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
35
tests/backend/model_management/test_memory_snapshot.py
Normal file
35
tests/backend/model_management/test_memory_snapshot.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
|
||||||
|
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_snapshot_capture():
|
||||||
|
"""Smoke test of MemorySnapshot.capture()."""
|
||||||
|
snapshot = MemorySnapshot.capture()
|
||||||
|
|
||||||
|
# We just check process_ram, because it is the only field that should be supported on all platforms.
|
||||||
|
assert snapshot.process_ram > 0
|
||||||
|
|
||||||
|
|
||||||
|
snapshots = [
|
||||||
|
MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
|
||||||
|
MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
|
||||||
|
MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
|
||||||
|
MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("snapshot_1", snapshots)
|
||||||
|
@pytest.mark.parametrize("snapshot_2", snapshots)
|
||||||
|
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
|
||||||
|
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
|
||||||
|
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
|
||||||
|
|
||||||
|
expected_lines = 1
|
||||||
|
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
||||||
|
expected_lines += 1
|
||||||
|
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
|
||||||
|
expected_lines += 5
|
||||||
|
|
||||||
|
assert len(msg.splitlines()) == expected_lines
|
Loading…
x
Reference in New Issue
Block a user