diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_ephemeral_disk.py new file mode 100644 index 0000000000..fffa65304f --- /dev/null +++ b/tests/test_object_serializer_ephemeral_disk.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache + + +@dataclass +class MockDataclass: + foo: str + + +@pytest.fixture +def obj_serializer(tmp_path: Path): + return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + +@pytest.fixture +def fwd_cache(tmp_path: Path): + return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + + +def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + assert obj_serializer._output_dir == tmp_path + + +def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert obj_serializer.load(obj_1_name).foo == "bar" + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert obj_serializer.load(obj_2_name).foo == "baz" + + with pytest.raises(ObjectNotFoundError): + obj_serializer.load("nonexistent_object_name") + + +def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + obj_serializer.delete(obj_1_name) + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + delete_all_result = obj_serializer._delete_all() + + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert not Path(obj_serializer._output_dir, obj_2_name).exists() + assert delete_all_result.deleted_count == 2 + + +def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + obj_1_loaded = obj_serializer.load(obj_1_name) + assert isinstance(obj_1_loaded, MockDataclass) + assert obj_1_loaded.foo == "bar" + assert obj_1_name.startswith("MockDataclass_") + + obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_2_name = obj_serializer.save(9001) + assert obj_serializer.load(obj_2_name) == 9001 + assert obj_2_name.startswith("int_") + + obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_3_name = obj_serializer.save("foo") + assert obj_serializer.load(obj_3_name) == "foo" + assert obj_3_name.startswith("str_") + + obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer.load(obj_4_name) + assert isinstance(obj_4_loaded, torch.Tensor) + assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) + assert obj_4_name.startswith("Tensor_") + + +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + fwd_cache = ObjectSerializerForwardCache(obj_serializer) + assert fwd_cache._underlying_storage == obj_serializer + + +def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj = MockDataclass(foo="bar") + obj_name = fwd_cache.save(obj) + obj_loaded = fwd_cache.load(obj_name) + obj_underlying = fwd_cache._underlying_storage.load(obj_name) + assert obj_loaded == obj_underlying + assert obj_loaded.foo == "bar" + + +def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = fwd_cache.save(obj_1) + obj_2 = MockDataclass(foo="baz") + obj_2_name = fwd_cache.save(obj_2) + obj_3 = MockDataclass(foo="qux") + obj_3_name = fwd_cache.save(obj_3) + assert obj_1_name not in fwd_cache._cache + assert obj_2_name in fwd_cache._cache + assert obj_3_name in fwd_cache._cache + # apparently qsize is "not reliable"? + assert fwd_cache._cache_ids.qsize() == 2 + + +def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + called_name = None + obj_1 = MockDataclass(foo="bar") + + def on_deleted(name: str): + nonlocal called_name + called_name = name + + fwd_cache.on_deleted(on_deleted) + obj_1_name = fwd_cache.save(obj_1) + fwd_cache.delete(obj_1_name) + assert called_name == obj_1_name