diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_management/memory_snapshot.py index 01f1328114..fe54af191c 100644 --- a/invokeai/backend/model_management/memory_snapshot.py +++ b/invokeai/backend/model_management/memory_snapshot.py @@ -64,7 +64,7 @@ class MemorySnapshot: return cls(process_ram, vram, malloc_info) -def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str: +def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: """Get a pretty string describing the difference between two `MemorySnapshot`s.""" def get_msg_line(prefix: str, val1: int, val2: int): @@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnaps msg = "" + if snapshot_1 is None or snapshot_2 is None: + return msg + msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_management/test_memory_snapshot.py index 80aed7b7ba..dcbb173e96 100644 --- a/tests/backend/model_management/test_memory_snapshot.py +++ b/tests/backend/model_management/test_memory_snapshot.py @@ -17,6 +17,7 @@ snapshots = [ MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None), MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()), MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None), + None, ] @@ -26,10 +27,12 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2): """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields.""" msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2) - expected_lines = 1 - if snapshot_1.vram is not None and snapshot_2.vram is not None: + expected_lines = 0 + if snapshot_1 is not None and snapshot_2 is not None: expected_lines += 1 - if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: - expected_lines += 5 + if snapshot_1.vram is not None and snapshot_2.vram is not None: + expected_lines += 1 + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + expected_lines += 5 assert len(msg.splitlines()) == expected_lines