refactor(minor): Latent Disk Storage

This commit is contained in:
blessedcoolant 2023-06-15 02:16:09 +12:00
parent b4c998a9ae
commit 587297878a

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching""" """Stores latents in a folder on disk without caching"""
__output_folder: str __output_folder: str | Path
def __init__(self, output_folder: str): def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
Path(output_folder).mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor: def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name) latent_path = self.get_path(name)
return torch.load(latent_path) return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None: def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name) latent_path = self.get_path(name)
torch.save(data, latent_path) torch.save(data, latent_path)
def delete(self, name: str) -> None: def delete(self, name: str) -> None:
latent_path = self.get_path(name) latent_path = self.get_path(name)
os.remove(latent_path) latent_path.unlink()
def get_path(self, name: str) -> str: def get_path(self, name: str) -> Path:
return os.path.join(self.__output_folder, name) return self.__output_folder / name