mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk
Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`. The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`. The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped. In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves. This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often. Tests updated.
This commit is contained in:
parent
199ddd6623
commit
b7ffd36cc6
@ -91,10 +91,10 @@ class ApiDependencies:
|
|||||||
images = ImageService()
|
images = ImageService()
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
tensors = ObjectSerializerForwardCache(
|
tensors = ObjectSerializerForwardCache(
|
||||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True)
|
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||||
)
|
)
|
||||||
conditioning = ObjectSerializerForwardCache(
|
conditioning = ObjectSerializerForwardCache(
|
||||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True)
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
)
|
)
|
||||||
model_manager = ModelManagerService(config, logger)
|
model_manager = ModelManagerService(config, logger)
|
||||||
model_record_service = ModelRecordServiceSQL(db=db)
|
model_record_service = ModelRecordServiceSQL(db=db)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import tempfile
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -23,28 +24,24 @@ class DeleteAllResult:
|
|||||||
|
|
||||||
|
|
||||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||||
"""Provides a disk-backed storage for arbitrary python objects.
|
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||||
|
|
||||||
:param output_folder: The folder where the objects will be stored
|
:param output_dir: The folder where the serialized objects will be stored
|
||||||
:param delete_on_startup: If True, all objects in the output folder will be deleted on startup
|
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, output_dir: Path, delete_on_startup: bool = False):
|
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._output_dir = output_dir
|
self._ephemeral = ephemeral
|
||||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
self._base_output_dir = output_dir
|
||||||
self._delete_on_startup = delete_on_startup
|
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||||
|
self._tempdir = (
|
||||||
|
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||||
|
)
|
||||||
|
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||||
self.__obj_class_name: Optional[str] = None
|
self.__obj_class_name: Optional[str] = None
|
||||||
|
|
||||||
def start(self, invoker: "Invoker") -> None:
|
|
||||||
if self._delete_on_startup:
|
|
||||||
delete_all_result = self._delete_all()
|
|
||||||
if delete_all_result.deleted_count > 0:
|
|
||||||
freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)
|
|
||||||
invoker.services.logger.info(
|
|
||||||
f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, name: str) -> T:
|
def load(self, name: str) -> T:
|
||||||
file_path = self._get_path(name)
|
file_path = self._get_path(name)
|
||||||
try:
|
try:
|
||||||
@ -75,19 +72,14 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
|||||||
def _new_name(self) -> str:
|
def _new_name(self) -> str:
|
||||||
return f"{self._obj_class_name}_{uuid_string()}"
|
return f"{self._obj_class_name}_{uuid_string()}"
|
||||||
|
|
||||||
def _delete_all(self) -> DeleteAllResult:
|
def _tempdir_cleanup(self) -> None:
|
||||||
"""
|
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||||
Deletes all objects from disk.
|
if self._tempdir:
|
||||||
"""
|
self._tempdir.cleanup()
|
||||||
|
|
||||||
# We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have
|
def __del__(self) -> None:
|
||||||
# to manually clear them on startup anyways. This is a bit simpler and more reliable.
|
# In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
|
||||||
|
self._tempdir_cleanup()
|
||||||
|
|
||||||
deleted_count = 0
|
def stop(self, invoker: "Invoker") -> None:
|
||||||
freed_space = 0
|
self._tempdir_cleanup()
|
||||||
for file in Path(self._output_dir).glob("*"):
|
|
||||||
if file.is_file():
|
|
||||||
freed_space += file.stat().st_size
|
|
||||||
deleted_count += 1
|
|
||||||
file.unlink()
|
|
||||||
return DeleteAllResult(deleted_count, freed_space)
|
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||||
@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path):
|
|||||||
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_invoker_with_logger():
|
|
||||||
return Mock(Invoker, services=Mock(logger=Mock(Logger)))
|
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
||||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||||
assert obj_serializer._output_dir == tmp_path
|
assert obj_serializer._output_dir == tmp_path
|
||||||
@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa
|
|||||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||||
|
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
|
||||||
|
assert obj_serializer._base_output_dir == tmp_path
|
||||||
|
assert obj_serializer._output_dir != tmp_path
|
||||||
|
assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||||
|
tempdir_path = obj_serializer._output_dir
|
||||||
|
del obj_serializer
|
||||||
|
assert not tempdir_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||||
|
tempdir_path = obj_serializer._output_dir
|
||||||
|
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
|
||||||
|
assert not tempdir_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||||
|
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||||
obj_1 = MockDataclass(foo="bar")
|
obj_1 = MockDataclass(foo="bar")
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
obj_1_name = obj_serializer.save(obj_1)
|
||||||
|
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||||
obj_2 = MockDataclass(foo="bar")
|
|
||||||
obj_2_name = obj_serializer.save(obj_2)
|
|
||||||
|
|
||||||
delete_all_result = obj_serializer._delete_all()
|
|
||||||
|
|
||||||
assert not Path(obj_serializer._output_dir, obj_1_name).exists()
|
|
||||||
assert not Path(obj_serializer._output_dir, obj_2_name).exists()
|
|
||||||
assert delete_all_result.deleted_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
|
|
||||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
|
||||||
assert obj_serializer._delete_on_startup is False
|
|
||||||
|
|
||||||
obj_1 = MockDataclass(foo="bar")
|
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
|
||||||
|
|
||||||
obj_serializer.start(mock_invoker_with_logger)
|
|
||||||
assert Path(tmp_path, obj_1_name).exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker):
|
|
||||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True)
|
|
||||||
assert obj_serializer._delete_on_startup is True
|
|
||||||
|
|
||||||
obj_1 = MockDataclass(foo="bar")
|
|
||||||
obj_1_name = obj_serializer.save(obj_1)
|
|
||||||
|
|
||||||
obj_serializer.start(mock_invoker_with_logger)
|
|
||||||
assert not Path(tmp_path, obj_1_name).exists()
|
assert not Path(tmp_path, obj_1_name).exists()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user