diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0c80494616..2acb961aa7 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,7 +5,7 @@ from logging import Logger import torch from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore @@ -90,9 +90,11 @@ class ApiDependencies: image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + tensors = ObjectSerializerForwardCache( + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py similarity index 77% rename from invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py rename to invokeai/app/services/object_serializer/object_serializer_disk.py index 880848a142..174ff15192 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -22,26 +22,30 @@ class DeleteAllResult: freed_space_bytes: float -class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): - """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. +class ObjectSerializerDisk(ObjectSerializerBase[T]): + """Provides a disk-backed storage for arbitrary python objects. :param output_folder: The folder where the objects will be stored + :param delete_on_startup: If True, all objects in the output folder will be deleted on startup """ - def __init__(self, output_dir: Path): + def __init__(self, output_dir: Path, delete_on_startup: bool = False): super().__init__() self._output_dir = output_dir self._output_dir.mkdir(parents=True, exist_ok=True) + self._delete_on_startup = delete_on_startup self.__obj_class_name: Optional[str] = None def start(self, invoker: "Invoker") -> None: self._invoker = invoker - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + + if self._delete_on_startup: + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_disk.py similarity index 65% rename from tests/test_object_serializer_ephemeral_disk.py rename to tests/test_object_serializer_disk.py index fffa65304f..5ce1e57901 100644 --- a/tests/test_object_serializer_ephemeral_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from logging import Logger from pathlib import Path +from unittest.mock import Mock import pytest import torch +from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -14,22 +17,31 @@ class MockDataclass: foo: str +def count_files(path: Path): + return len(list(path.iterdir())) + + @pytest.fixture def obj_serializer(tmp_path: Path): - return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + return ObjectSerializerDisk[MockDataclass](tmp_path) @pytest.fixture def fwd_cache(tmp_path: Path): - return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) -def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +@pytest.fixture +def mock_invoker_with_logger(): + return Mock(Invoker, services=Mock(logger=Mock(Logger))) + + +def test_obj_serializer_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path -def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert Path(obj_serializer._output_dir, obj_1_name).exists() @@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert obj_serializer.load(obj_1_name).foo == "bar" @@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph obj_serializer.load("nonexistent_object_name") -def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali assert delete_all_result.deleted_count == 2 -def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) + assert obj_serializer._delete_on_startup is False + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) + assert obj_serializer._delete_on_startup is True + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert not Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") - obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_serializer = ObjectSerializerDisk[int](tmp_path) obj_2_name = obj_serializer.save(9001) assert obj_serializer.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") - obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_serializer = ObjectSerializerDisk[str](tmp_path) obj_3_name = obj_serializer.save("foo") assert obj_serializer.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) obj_4_loaded = obj_serializer.load(obj_4_name) assert isinstance(obj_4_loaded, torch.Tensor) @@ -106,7 +140,7 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_4_name.startswith("Tensor_") -def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]): fwd_cache = ObjectSerializerForwardCache(obj_serializer) assert fwd_cache._underlying_storage == obj_serializer