diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 2acb961aa7..0f2a92b5c8 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -91,10 +91,10 @@ class ApiDependencies: images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) tensors = ObjectSerializerForwardCache( - ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 06f86aa460..935fec3060 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,3 +1,4 @@ +import tempfile import typing from dataclasses import dataclass from pathlib import Path @@ -23,28 +24,24 @@ class DeleteAllResult: class ObjectSerializerDisk(ObjectSerializerBase[T]): - """Provides a disk-backed storage for arbitrary python objects. + """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. - :param output_folder: The folder where the objects will be stored - :param delete_on_startup: If True, all objects in the output folder will be deleted on startup + :param output_dir: The folder where the serialized objects will be stored + :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit """ - def __init__(self, output_dir: Path, delete_on_startup: bool = False): + def __init__(self, output_dir: Path, ephemeral: bool = False): super().__init__() - self._output_dir = output_dir - self._output_dir.mkdir(parents=True, exist_ok=True) - self._delete_on_startup = delete_on_startup + self._ephemeral = ephemeral + self._base_output_dir = output_dir + self._base_output_dir.mkdir(parents=True, exist_ok=True) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows + self._tempdir = ( + tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None + ) + self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir self.__obj_class_name: Optional[str] = None - def start(self, invoker: "Invoker") -> None: - if self._delete_on_startup: - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) - def load(self, name: str) -> T: file_path = self._get_path(name) try: @@ -75,19 +72,14 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]): def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> DeleteAllResult: - """ - Deletes all objects from disk. - """ + def _tempdir_cleanup(self) -> None: + """Calls `cleanup` on the temporary directory, if it exists.""" + if self._tempdir: + self._tempdir.cleanup() - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. + def __del__(self) -> None: + # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd. + self._tempdir_cleanup() - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_dir).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - return DeleteAllResult(deleted_count, freed_space) + def stop(self, invoker: "Invoker") -> None: + self._tempdir_cleanup() diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 2bc7e16937..125534c500 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,12 +1,10 @@ +import tempfile from dataclasses import dataclass -from logging import Logger from pathlib import Path -from unittest.mock import Mock import pytest import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path): return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) -@pytest.fixture -def mock_invoker_with_logger(): - return Mock(Invoker, services=Mock(logger=Mock(Logger))) - - def test_obj_serializer_disk_initializes(tmp_path: Path): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path @@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): +def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory) + assert obj_serializer._base_output_dir == tmp_path + assert obj_serializer._output_dir != tmp_path + assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name) + + +def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + del obj_serializer + assert not tempdir_path.exists() + + +def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + obj_serializer.stop(None) # pyright: ignore [reportArgumentType] + assert not tempdir_path.exists() + + +def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) - - obj_2 = MockDataclass(foo="bar") - obj_2_name = obj_serializer.save(obj_2) - - delete_all_result = obj_serializer._delete_all() - - assert not Path(obj_serializer._output_dir, obj_1_name).exists() - assert not Path(obj_serializer._output_dir, obj_2_name).exists() - assert delete_all_result.deleted_count == 2 - - -def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) - assert obj_serializer._delete_on_startup is False - - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) - assert Path(tmp_path, obj_1_name).exists() - - -def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) - assert obj_serializer._delete_on_startup is True - - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) + assert Path(obj_serializer._output_dir, obj_1_name).exists() assert not Path(tmp_path, obj_1_name).exists()