mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): make delete on startup configurable for obj serializer
- The default is to not delete on startup - feels safer. - The two services using this class _do_ delete on startup. - The class has "ephemeral" removed from its name. - Tests & app updated for this change.
This commit is contained in:
parent
091f4cb583
commit
9edb995647
@ -5,7 +5,7 @@ from logging import Logger
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||||
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||||
@ -90,9 +90,11 @@ class ApiDependencies:
|
|||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors"))
|
tensors = ObjectSerializerForwardCache(
|
||||||
|
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True)
|
||||||
|
)
|
||||||
conditioning = ObjectSerializerForwardCache(
|
conditioning = ObjectSerializerForwardCache(
|
||||||
ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning")
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True)
|
||||||
)
|
)
|
||||||
model_manager = ModelManagerService(config, logger)
|
model_manager = ModelManagerService(config, logger)
|
||||||
model_record_service = ModelRecordServiceSQL(db=db)
|
model_record_service = ModelRecordServiceSQL(db=db)
|
||||||
|
@ -22,26 +22,30 @@ class DeleteAllResult:
|
|||||||
freed_space_bytes: float
|
freed_space_bytes: float
|
||||||
|
|
||||||
|
|
||||||
class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
|
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||||
"""Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.
|
"""Provides a disk-backed storage for arbitrary python objects.
|
||||||
|
|
||||||
:param output_folder: The folder where the objects will be stored
|
:param output_folder: The folder where the objects will be stored
|
||||||
|
:param delete_on_startup: If True, all objects in the output folder will be deleted on startup
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, output_dir: Path):
|
def __init__(self, output_dir: Path, delete_on_startup: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._output_dir = output_dir
|
self._output_dir = output_dir
|
||||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._delete_on_startup = delete_on_startup
|
||||||
self.__obj_class_name: Optional[str] = None
|
self.__obj_class_name: Optional[str] = None
|
||||||
|
|
||||||
def start(self, invoker: "Invoker") -> None:
|
def start(self, invoker: "Invoker") -> None:
|
||||||
self._invoker = invoker
|
self._invoker = invoker
|
||||||
delete_all_result = self._delete_all()
|
|
||||||
if delete_all_result.deleted_count > 0:
|
if self._delete_on_startup:
|
||||||
freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)
|
delete_all_result = self._delete_all()
|
||||||
self._invoker.services.logger.info(
|
if delete_all_result.deleted_count > 0:
|
||||||
f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
|
freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)
|
||||||
)
|
self._invoker.services.logger.info(
|
||||||
|
f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
|
||||||
|
)
|
||||||
|
|
||||||
def load(self, name: str) -> T:
|
def load(self, name: str) -> T:
|
||||||
file_path = self._get_path(name)
|
file_path = self._get_path(name)
|
@ -1,11 +1,14 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||||
from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||||
|
|
||||||
|
|
||||||
@ -14,22 +17,31 @@ class MockDataclass:
|
|||||||
foo: str
|
foo: str
|
||||||
|
|
||||||
|
|
||||||
|
def count_files(path: Path):
|
||||||
|
return len(list(path.iterdir()))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def obj_serializer(tmp_path: Path):
|
def obj_serializer(tmp_path: Path):
|
||||||
return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path)
|
return ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def fwd_cache(tmp_path: Path):
|
def fwd_cache(tmp_path: Path):
|
||||||
return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2)
|
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path):
|
@pytest.fixture
|
||||||
obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path)
|
def mock_invoker_with_logger():
|
||||||
|
return Mock(Invoker, services=Mock(logger=Mock(Logger)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
assert obj_serializer._output_dir == tmp_path
|
assert obj_serializer._output_dir == tmp_path
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
|
def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||||
@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph
|
|||||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
|
def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
assert obj_serializer.load(obj_1_name).foo == "bar"
|
assert obj_serializer.load(obj_1_name).foo == "bar"
|
||||||
@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph
|
|||||||
obj_serializer.load("nonexistent_object_name")
|
obj_serializer.load("nonexistent_object_name")
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
|
def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
|
|
||||||
@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE
|
|||||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
|
def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
|
|
||||||
@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali
|
|||||||
assert delete_all_result.deleted_count == 2
|
assert delete_all_result.deleted_count == 2
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
|
def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
|
||||||
obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path)
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
|
assert obj_serializer._delete_on_startup is False
|
||||||
|
|
||||||
|
obj_1 = MockDataclass(foo="bar")
|
||||||
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
|
|
||||||
|
obj_serializer.start(mock_invoker_with_logger)
|
||||||
|
assert Path(tmp_path, obj_1_name).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True)
|
||||||
|
assert obj_serializer._delete_on_startup is True
|
||||||
|
|
||||||
|
obj_1 = MockDataclass(foo="bar")
|
||||||
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
|
|
||||||
|
obj_serializer.start(mock_invoker_with_logger)
|
||||||
|
assert not Path(tmp_path, obj_1_name).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
|
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
|
|||||||
assert obj_1_loaded.foo == "bar"
|
assert obj_1_loaded.foo == "bar"
|
||||||
assert obj_1_name.startswith("MockDataclass_")
|
assert obj_1_name.startswith("MockDataclass_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path)
|
obj_serializer = ObjectSerializerDisk[int](tmp_path)
|
||||||
obj_2_name = obj_serializer.save(9001)
|
obj_2_name = obj_serializer.save(9001)
|
||||||
assert obj_serializer.load(obj_2_name) == 9001
|
assert obj_serializer.load(obj_2_name) == 9001
|
||||||
assert obj_2_name.startswith("int_")
|
assert obj_2_name.startswith("int_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path)
|
obj_serializer = ObjectSerializerDisk[str](tmp_path)
|
||||||
obj_3_name = obj_serializer.save("foo")
|
obj_3_name = obj_serializer.save("foo")
|
||||||
assert obj_serializer.load(obj_3_name) == "foo"
|
assert obj_serializer.load(obj_3_name) == "foo"
|
||||||
assert obj_3_name.startswith("str_")
|
assert obj_3_name.startswith("str_")
|
||||||
|
|
||||||
obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path)
|
obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
||||||
obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
|
obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3]))
|
||||||
obj_4_loaded = obj_serializer.load(obj_4_name)
|
obj_4_loaded = obj_serializer.load(obj_4_name)
|
||||||
assert isinstance(obj_4_loaded, torch.Tensor)
|
assert isinstance(obj_4_loaded, torch.Tensor)
|
||||||
@ -106,7 +140,7 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
|
|||||||
assert obj_4_name.startswith("Tensor_")
|
assert obj_4_name.startswith("Tensor_")
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
|
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||||
fwd_cache = ObjectSerializerForwardCache(obj_serializer)
|
fwd_cache = ObjectSerializerForwardCache(obj_serializer)
|
||||||
assert fwd_cache._underlying_storage == obj_serializer
|
assert fwd_cache._underlying_storage == obj_serializer
|
||||||
|
|
Loading…
Reference in New Issue
Block a user