From 587297878a7f76cbfcf4e5fbdd43a5dafea33201 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Thu, 15 Jun 2023 02:16:09 +1200 Subject: [PATCH] refactor(minor): Latent Disk Storage --- invokeai/app/services/latent_storage.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py index 519c254087..17d35d7c33 100644 --- a/invokeai/app/services/latent_storage.py +++ b/invokeai/app/services/latent_storage.py @@ -1,6 +1,5 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -import os from abc import ABC, abstractmethod from pathlib import Path from queue import Queue @@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase): """Stores latents in a folder on disk without caching""" - __output_folder: str + __output_folder: str | Path - def __init__(self, output_folder: str): - self.__output_folder = output_folder - Path(output_folder).mkdir(parents=True, exist_ok=True) + def __init__(self, output_folder: str | Path): + self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) + self.__output_folder.mkdir(parents=True, exist_ok=True) def get(self, name: str) -> torch.Tensor: latent_path = self.get_path(name) return torch.load(latent_path) def save(self, name: str, data: torch.Tensor) -> None: + self.__output_folder.mkdir(parents=True, exist_ok=True) latent_path = self.get_path(name) torch.save(data, latent_path) def delete(self, name: str) -> None: latent_path = self.get_path(name) - os.remove(latent_path) + latent_path.unlink() + - def get_path(self, name: str) -> str: - return os.path.join(self.__output_folder, name) + def get_path(self, name: str) -> Path: + return self.__output_folder / name \ No newline at end of file