feat(nodes): make delete on startup configurable for obj serializer

- The default is to not delete on startup - feels safer.
- The two services using this class _do_ delete on startup.
- The class has "ephemeral" removed from its name.
- Tests & app updated for this change.
This commit is contained in:
psychedelicious 2024-02-08 16:09:59 +11:00
parent 091f4cb583
commit 9edb995647
3 changed files with 67 additions and 27 deletions

View File

@ -5,7 +5,7 @@ from logging import Logger
import torch import torch
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.model_manager.metadata import ModelMetadataStore
@ -90,9 +90,11 @@ class ApiDependencies:
image_records = SqliteImageRecordStorage(db=db) image_records = SqliteImageRecordStorage(db=db)
images = ImageService() images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) tensors = ObjectSerializerForwardCache(
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True)
)
conditioning = ObjectSerializerForwardCache( conditioning = ObjectSerializerForwardCache(
ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True)
) )
model_manager = ModelManagerService(config, logger) model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db) model_record_service = ModelRecordServiceSQL(db=db)

View File

@ -22,20 +22,24 @@ class DeleteAllResult:
freed_space_bytes: float freed_space_bytes: float
class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. """Provides a disk-backed storage for arbitrary python objects.
:param output_folder: The folder where the objects will be stored :param output_folder: The folder where the objects will be stored
:param delete_on_startup: If True, all objects in the output folder will be deleted on startup
""" """
def __init__(self, output_dir: Path): def __init__(self, output_dir: Path, delete_on_startup: bool = False):
super().__init__() super().__init__()
self._output_dir = output_dir self._output_dir = output_dir
self._output_dir.mkdir(parents=True, exist_ok=True) self._output_dir.mkdir(parents=True, exist_ok=True)
self._delete_on_startup = delete_on_startup
self.__obj_class_name: Optional[str] = None self.__obj_class_name: Optional[str] = None
def start(self, invoker: "Invoker") -> None: def start(self, invoker: "Invoker") -> None:
self._invoker = invoker self._invoker = invoker
if self._delete_on_startup:
delete_all_result = self._delete_all() delete_all_result = self._delete_all()
if delete_all_result.deleted_count > 0: if delete_all_result.deleted_count > 0:
freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)

View File

@ -1,11 +1,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger
from pathlib import Path from pathlib import Path
from unittest.mock import Mock
import pytest import pytest
import torch import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@ -14,22 +17,31 @@ class MockDataclass:
foo: str foo: str
def count_files(path: Path):
return len(list(path.iterdir()))
@pytest.fixture @pytest.fixture
def obj_serializer(tmp_path: Path): def obj_serializer(tmp_path: Path):
return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) return ObjectSerializerDisk[MockDataclass](tmp_path)
@pytest.fixture @pytest.fixture
def fwd_cache(tmp_path: Path): def fwd_cache(tmp_path: Path):
return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): @pytest.fixture
obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) def mock_invoker_with_logger():
return Mock(Invoker, services=Mock(logger=Mock(Logger)))
def test_obj_serializer_disk_initializes(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._output_dir == tmp_path assert obj_serializer._output_dir == tmp_path
def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar") obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1) obj_1_name = obj_serializer.save(obj_1)
assert Path(obj_serializer._output_dir, obj_1_name).exists() assert Path(obj_serializer._output_dir, obj_1_name).exists()
@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph
assert Path(obj_serializer._output_dir, obj_2_name).exists() assert Path(obj_serializer._output_dir, obj_2_name).exists()
def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar") obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1) obj_1_name = obj_serializer.save(obj_1)
assert obj_serializer.load(obj_1_name).foo == "bar" assert obj_serializer.load(obj_1_name).foo == "bar"
@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph
obj_serializer.load("nonexistent_object_name") obj_serializer.load("nonexistent_object_name")
def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar") obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1) obj_1_name = obj_serializer.save(obj_1)
@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE
assert Path(obj_serializer._output_dir, obj_2_name).exists() assert Path(obj_serializer._output_dir, obj_2_name).exists()
def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]):
obj_1 = MockDataclass(foo="bar") obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1) obj_1_name = obj_serializer.save(obj_1)
@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali
assert delete_all_result.deleted_count == 2 assert delete_all_result.deleted_count == 2
def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
assert obj_serializer._delete_on_startup is False
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_serializer.start(mock_invoker_with_logger)
assert Path(tmp_path, obj_1_name).exists()
def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True)
assert obj_serializer._delete_on_startup is True
obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1)
obj_serializer.start(mock_invoker_with_logger)
assert not Path(tmp_path, obj_1_name).exists()
def test_obj_serializer_disk_different_types(tmp_path: Path):
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
obj_1 = MockDataclass(foo="bar") obj_1 = MockDataclass(foo="bar")
obj_1_name = obj_serializer.save(obj_1) obj_1_name = obj_serializer.save(obj_1)
@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
assert obj_1_loaded.foo == "bar" assert obj_1_loaded.foo == "bar"
assert obj_1_name.startswith("MockDataclass_") assert obj_1_name.startswith("MockDataclass_")
obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) obj_serializer = ObjectSerializerDisk[int](tmp_path)
obj_2_name = obj_serializer.save(9001) obj_2_name = obj_serializer.save(9001)
assert obj_serializer.load(obj_2_name) == 9001 assert obj_serializer.load(obj_2_name) == 9001
assert obj_2_name.startswith("int_") assert obj_2_name.startswith("int_")
obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) obj_serializer = ObjectSerializerDisk[str](tmp_path)
obj_3_name = obj_serializer.save("foo") obj_3_name = obj_serializer.save("foo")
assert obj_serializer.load(obj_3_name) == "foo" assert obj_serializer.load(obj_3_name) == "foo"
assert obj_3_name.startswith("str_") assert obj_3_name.startswith("str_")
obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path)
obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
obj_4_loaded = obj_serializer.load(obj_4_name) obj_4_loaded = obj_serializer.load(obj_4_name)
assert isinstance(obj_4_loaded, torch.Tensor) assert isinstance(obj_4_loaded, torch.Tensor)
@ -106,7 +140,7 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
assert obj_4_name.startswith("Tensor_") assert obj_4_name.startswith("Tensor_")
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
fwd_cache = ObjectSerializerForwardCache(obj_serializer) fwd_cache = ObjectSerializerForwardCache(obj_serializer)
assert fwd_cache._underlying_storage == obj_serializer assert fwd_cache._underlying_storage == obj_serializer