feat(backend): clear latents files on startup

Adds logic to `DiskLatentsStorage.start()` to empty the latents folder on startup.

Adds start and stop methods to `ForwardCacheLatentsStorage`. This is required for `DiskLatentsStorage.start()` to be called, due to how this particular service breaks the direct DI pattern, wrapping the underlying storage with a cache.
This commit is contained in:
psychedelicious 2023-11-30 19:57:29 +11:00 committed by Kent Keirsey
parent cff6600ded
commit 1fd6666682
2 changed files with 33 additions and 0 deletions

View File

@ -5,6 +5,8 @@ from typing import Union
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
@ -17,6 +19,23 @@ class DiskLatentsStorage(LatentsStorageBase):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
# Delete all latents files on startup
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self.__output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} latents files, freeing {freed_space_in_mb}MB"
)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)

View File

@ -5,6 +5,8 @@ from typing import Dict, Optional
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
@ -23,6 +25,18 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self.__underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self.__underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None: