mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(minor): Latent Disk Storage
This commit is contained in:
parent
b4c998a9ae
commit
587297878a
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
class DiskLatentsStorage(LatentsStorageBase):
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
"""Stores latents in a folder on disk without caching"""
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
__output_folder: str
|
__output_folder: str | Path
|
||||||
|
|
||||||
def __init__(self, output_folder: str):
|
def __init__(self, output_folder: str | Path):
|
||||||
self.__output_folder = output_folder
|
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
def get(self, name: str) -> torch.Tensor:
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
return torch.load(latent_path)
|
return torch.load(latent_path)
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
torch.save(data, latent_path)
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
def delete(self, name: str) -> None:
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
os.remove(latent_path)
|
latent_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def get_path(self, name: str) -> str:
|
def get_path(self, name: str) -> Path:
|
||||||
return os.path.join(self.__output_folder, name)
|
return self.__output_folder / name
|
||||||
|
|
Loading…
Reference in New Issue
Block a user