mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: test ObjectSerializerDisk class name extraction
This commit is contained in:
parent
dc003a4bac
commit
cda9ab7933
@ -113,28 +113,31 @@ def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with
|
|||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
|
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer_1.save(obj_1)
|
||||||
obj_1_loaded = obj_serializer.load(obj_1_name)
|
obj_1_loaded = obj_serializer_1.load(obj_1_name)
|
||||||
|
assert obj_serializer_1._obj_class_name == "MockDataclass"
|
||||||
assert isinstance(obj_1_loaded, MockDataclass)
|
assert isinstance(obj_1_loaded, MockDataclass)
|
||||||
assert obj_1_loaded.foo == "bar"
|
assert obj_1_loaded.foo == "bar"
|
||||||
assert obj_1_name.startswith("MockDataclass_")
|
assert obj_1_name.startswith("MockDataclass_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerDisk[int](tmp_path)
|
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
|
||||||
obj_2_name = obj_serializer.save(9001)
|
obj_2_name = obj_serializer_2.save(9001)
|
||||||
assert obj_serializer.load(obj_2_name) == 9001
|
assert obj_serializer_2._obj_class_name == "int"
|
||||||
|
assert obj_serializer_2.load(obj_2_name) == 9001
|
||||||
assert obj_2_name.startswith("int_")
|
assert obj_2_name.startswith("int_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerDisk[str](tmp_path)
|
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
|
||||||
obj_3_name = obj_serializer.save("foo")
|
obj_3_name = obj_serializer_3.save("foo")
|
||||||
assert obj_serializer.load(obj_3_name) == "foo"
|
assert obj_serializer_3._obj_class_name == "str"
|
||||||
|
assert obj_serializer_3.load(obj_3_name) == "foo"
|
||||||
assert obj_3_name.startswith("str_")
|
assert obj_3_name.startswith("str_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
||||||
obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
|
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
|
||||||
obj_4_loaded = obj_serializer.load(obj_4_name)
|
obj_4_loaded = obj_serializer_4.load(obj_4_name)
|
||||||
|
assert obj_serializer_4._obj_class_name == "Tensor"
|
||||||
assert isinstance(obj_4_loaded, torch.Tensor)
|
assert isinstance(obj_4_loaded, torch.Tensor)
|
||||||
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
|
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
|
||||||
assert obj_4_name.startswith("Tensor_")
|
assert obj_4_name.startswith("Tensor_")
|
||||||
|
Loading…
Reference in New Issue
Block a user