refactor(minor): Image & Latent File Storage (#3538)

- `DiskImageStorage` and `DiskLatentsStorage` have now both been updated
to exclusively work with `Path` objects and not rely on the `os` lib to
handle pathing related functions.
- We now also validate the existence of the required image output
folders and latent output folders to ensure that the app does not break
in case the required folders get tampered with mid-session.
- Just overall general cleanup.

Tested it. Don't seem to be any thing breaking.
This commit is contained in:
blessedcoolant 2023-06-15 02:43:27 +12:00 committed by GitHub
commit 70ece4364c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 49 deletions

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -76,28 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath?
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
Path(os.path.join(output_folder, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@ -116,6 +113,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_size: int = 256,
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
@ -137,10 +135,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(basename)
image_path = self.get_path(image_name)
if os.path.exists(image_path):
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
@ -148,7 +145,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
@ -156,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder,
"thumbnails",
thumbnail_name,
)
else:
path = os.path.join(self.__output_folder, basename)
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
abspath = os.path.abspath(path)
return path
return abspath
def validate_path(self, path: str) -> bool:
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: str) -> PILImageType | None:
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name