refactor(minor): Image File Storage

This commit is contained in:
blessedcoolant 2023-06-15 01:58:58 +12:00
parent 88e8e3977b
commit b4c998a9ae

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
@ -76,28 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase): class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: str __output_folder: str | Path
__cache_ids: Queue # TODO: this is an incredibly naive cache __cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType] __cache: Dict[Path, PILImageType]
__max_cache_size: int __max_cache_size: int
def __init__(self, output_folder: str): def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True) self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath? # Validate required output folders at launch
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True) self.validate_storage_folders()
Path(os.path.join(output_folder, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_name: str) -> PILImageType: def get(self, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -116,6 +113,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
self.validate_storage_folders()
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
if metadata is not None: if metadata is not None:
@ -137,10 +135,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
def delete(self, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) image_path = self.get_path(image_name)
image_path = self.get_path(basename)
if os.path.exists(image_path): if image_path.exists():
send2trash(image_path) send2trash(image_path)
if image_path in self.__cache: if image_path in self.__cache:
del self.__cache[image_path] del self.__cache[image_path]
@ -148,7 +145,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, True) thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path): if thumbnail_path.exists():
send2trash(thumbnail_path) send2trash(thumbnail_path)
if thumbnail_path in self.__cache: if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path] del self.__cache[thumbnail_path]
@ -156,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
# strip out any relative path shenanigans path = self.__output_folder / image_name
basename = os.path.basename(image_name)
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(image_name)
path = os.path.join( path = self.__thumbnails_folder / thumbnail_name
self.__output_folder,
"thumbnails",
thumbnail_name,
)
else:
path = os.path.join(self.__output_folder, basename)
abspath = os.path.abspath(path) return path
return abspath def validate_path(self, path: str | Path) -> bool:
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail.""" """Validates the path given for an image or thumbnail."""
try: path = path if isinstance(path, Path) else Path(path)
os.stat(path) return path.exists()
return True
except:
return False
def __get_cache(self, image_name: str) -> PILImageType | None: def validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache: if not image_name in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put( self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size: if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get() cache_id = self.__cache_ids.get()
if cache_id in self.__cache: if cache_id in self.__cache: