tests: test ObjectSerializerDisk class name extraction

This commit is contained in:
psychedelicious 2024-02-10 18:46:51 +11:00
parent 670f2f75e9
commit 11f64dab38

View File

@ -113,28 +113,31 @@ def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with
def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_1_loaded = obj_serializer.load(obj_1_name)
obj_1_name = obj_serializer_1.save(obj_1)
obj_1_loaded = obj_serializer_1.load(obj_1_name)
assert obj_serializer_1._obj_class_name == "MockDataclass"
assert isinstance(obj_1_loaded, MockDataclass)
assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_")
obj_serializer = ObjectSerializerDisk[int](tmp_path)
obj_2_name = obj_serializer.save(9001)
assert obj_serializer.load(obj_2_name) == 9001
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
obj_2_name = obj_serializer_2.save(9001)
assert obj_serializer_2._obj_class_name == "int"
assert obj_serializer_2.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_")
obj_serializer = ObjectSerializerDisk[str](tmp_path)
obj_3_name = obj_serializer.save("foo")
assert obj_serializer.load(obj_3_name) == "foo"
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
obj_3_name = obj_serializer_3.save("foo")
assert obj_serializer_3._obj_class_name == "str"
assert obj_serializer_3.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_")
obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer.load(obj_4_name)
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer_4.load(obj_4_name)
assert obj_serializer_4._obj_class_name == "Tensor"
assert isinstance(obj_4_loaded, torch.Tensor)
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
assert obj_4_name.startswith("Tensor_")