feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk

Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`.

The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`.

The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped.

In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves.

This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often.

Tests updated.
This commit is contained in:
psychedelicious
2024-02-10 19:11:28 +11:00
parent 11f64dab38
commit fece935438
3 changed files with 50 additions and 71 deletions

View File

@ -1,12 +1,10 @@
import tempfile
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from unittest.mock import Mock
import pytest
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
@pytest.fixture
def mock_invoker_with_logger():
return Mock(Invoker, services=Mock(logger=Mock(Logger)))
def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._output_dir == tmp_path
@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa
assert Path(obj_serializer._output_dir, obj_2_name).exists()
def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]):
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
assert obj_serializer._base_output_dir == tmp_path
assert obj_serializer._output_dir != tmp_path
assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name)
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
del obj_serializer
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
tempdir_path = obj_serializer._output_dir
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
assert not tempdir_path.exists()
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_2 = MockDataclass(foo="bar")
obj_2_name = obj_serializer.save(obj_2)
delete_all_result = obj_serializer._delete_all()
assert not Path(obj_serializer._output_dir, obj_1_name).exists()
assert not Path(obj_serializer._output_dir, obj_2_name).exists()
assert delete_all_result.deleted_count == 2
def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._delete_on_startup is False
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_serializer.start(mock_invoker_with_logger)
assert Path(tmp_path, obj_1_name).exists()
def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True)
assert obj_serializer._delete_on_startup is True
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_serializer.start(mock_invoker_with_logger)
assert Path(obj_serializer._output_dir, obj_1_name).exists()
assert not Path(tmp_path, obj_1_name).exists()