from dataclasses import dataclass from logging import Logger from pathlib import Path from unittest.mock import Mock import pytest import torch from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @dataclass class MockDataclass: foo: str def count_files(path: Path): return len(list(path.iterdir())) @pytest.fixture def obj_serializer(tmp_path: Path): return ObjectSerializerDisk[MockDataclass](tmp_path) @pytest.fixture def fwd_cache(tmp_path: Path): return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) @pytest.fixture def mock_invoker_with_logger(): return Mock(Invoker, services=Mock(logger=Mock(Logger))) def test_obj_serializer_disk_initializes(tmp_path: Path): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert Path(obj_serializer._output_dir, obj_1_name).exists() obj_2 = MockDataclass(foo="baz") obj_2_name = obj_serializer.save(obj_2) assert Path(obj_serializer._output_dir, obj_2_name).exists() def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert obj_serializer.load(obj_1_name).foo == "bar" obj_2 = MockDataclass(foo="baz") obj_2_name = obj_serializer.save(obj_2) assert obj_serializer.load(obj_2_name).foo == "baz" with pytest.raises(ObjectNotFoundError): obj_serializer.load("nonexistent_object_name") def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) obj_2 = MockDataclass(foo="bar") obj_2_name = obj_serializer.save(obj_2) obj_serializer.delete(obj_1_name) assert not Path(obj_serializer._output_dir, obj_1_name).exists() assert Path(obj_serializer._output_dir, obj_2_name).exists() def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) obj_2 = MockDataclass(foo="bar") obj_2_name = obj_serializer.save(obj_2) delete_all_result = obj_serializer._delete_all() assert not Path(obj_serializer._output_dir, obj_1_name).exists() assert not Path(obj_serializer._output_dir, obj_2_name).exists() assert delete_all_result.deleted_count == 2 def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._delete_on_startup is False obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) obj_serializer.start(mock_invoker_with_logger) assert Path(tmp_path, obj_1_name).exists() def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) assert obj_serializer._delete_on_startup is True obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) obj_serializer.start(mock_invoker_with_logger) assert not Path(tmp_path, obj_1_name).exists() def test_obj_serializer_disk_different_types(tmp_path: Path): obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer_1.save(obj_1) obj_1_loaded = obj_serializer_1.load(obj_1_name) assert obj_serializer_1._obj_class_name == "MockDataclass" assert isinstance(obj_1_loaded, MockDataclass) assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") obj_serializer_2 = ObjectSerializerDisk[int](tmp_path) obj_2_name = obj_serializer_2.save(9001) assert obj_serializer_2._obj_class_name == "int" assert obj_serializer_2.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") obj_serializer_3 = ObjectSerializerDisk[str](tmp_path) obj_3_name = obj_serializer_3.save("foo") assert obj_serializer_3._obj_class_name == "str" assert obj_serializer_3.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path) obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3])) obj_4_loaded = obj_serializer_4.load(obj_4_name) assert obj_serializer_4._obj_class_name == "Tensor" assert isinstance(obj_4_loaded, torch.Tensor) assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) assert obj_4_name.startswith("Tensor_") def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]): fwd_cache = ObjectSerializerForwardCache(obj_serializer) assert fwd_cache._underlying_storage == obj_serializer def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): obj = MockDataclass(foo="bar") obj_name = fwd_cache.save(obj) obj_loaded = fwd_cache.load(obj_name) obj_underlying = fwd_cache._underlying_storage.load(obj_name) assert obj_loaded == obj_underlying assert obj_loaded.foo == "bar" def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = fwd_cache.save(obj_1) obj_2 = MockDataclass(foo="baz") obj_2_name = fwd_cache.save(obj_2) obj_3 = MockDataclass(foo="qux") obj_3_name = fwd_cache.save(obj_3) assert obj_1_name not in fwd_cache._cache assert obj_2_name in fwd_cache._cache assert obj_3_name in fwd_cache._cache # apparently qsize is "not reliable"? assert fwd_cache._cache_ids.qsize() == 2 def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): called_name = None obj_1 = MockDataclass(foo="bar") def on_deleted(name: str): nonlocal called_name called_name = name fwd_cache.on_deleted(on_deleted) obj_1_name = fwd_cache.save(obj_1) fwd_cache.delete(obj_1_name) assert called_name == obj_1_name