mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(minor): Image & Latent File Storage (#3538)
- `DiskImageStorage` and `DiskLatentsStorage` have now both been updated to exclusively work with `Path` objects and not rely on the `os` lib to handle pathing related functions. - We now also validate the existence of the required image output folders and latent output folders to ensure that the app does not break in case the required folders get tampered with mid-session. - Just overall general cleanup. Tested it. Don't seem to be any thing breaking.
This commit is contained in:
commit
70ece4364c
@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -76,28 +75,26 @@ class ImageFileStorageBase(ABC):
|
|||||||
class DiskImageFileStorage(ImageFileStorageBase):
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
"""Stores images on disk"""
|
"""Stores images on disk"""
|
||||||
|
|
||||||
__output_folder: str
|
__output_folder: Path
|
||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[str, PILImageType]
|
__cache: Dict[Path, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
|
|
||||||
def __init__(self, output_folder: str):
|
def __init__(self, output_folder: str | Path):
|
||||||
self.__output_folder = output_folder
|
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
|
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
|
||||||
|
|
||||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
# Validate required output folders at launch
|
||||||
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
|
self.__validate_storage_folders()
|
||||||
Path(os.path.join(output_folder, "thumbnails")).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
def get(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
|
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
@ -116,6 +113,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
|
self.__validate_storage_folders()
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
@ -137,10 +135,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
def delete(self, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
basename = os.path.basename(image_name)
|
image_path = self.get_path(image_name)
|
||||||
image_path = self.get_path(basename)
|
|
||||||
|
|
||||||
if os.path.exists(image_path):
|
if image_path.exists():
|
||||||
send2trash(image_path)
|
send2trash(image_path)
|
||||||
if image_path in self.__cache:
|
if image_path in self.__cache:
|
||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
@ -148,7 +145,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
if thumbnail_path.exists():
|
||||||
send2trash(thumbnail_path)
|
send2trash(thumbnail_path)
|
||||||
if thumbnail_path in self.__cache:
|
if thumbnail_path in self.__cache:
|
||||||
del self.__cache[thumbnail_path]
|
del self.__cache[thumbnail_path]
|
||||||
@ -156,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
raise ImageFileDeleteException from e
|
raise ImageFileDeleteException from e
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||||
# strip out any relative path shenanigans
|
path = self.__output_folder / image_name
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
|
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
thumbnail_name = get_thumbnail_name(basename)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
path = os.path.join(
|
path = self.__thumbnails_folder / thumbnail_name
|
||||||
self.__output_folder,
|
|
||||||
"thumbnails",
|
|
||||||
thumbnail_name,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
path = os.path.join(self.__output_folder, basename)
|
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
return path
|
||||||
|
|
||||||
return abspath
|
def validate_path(self, path: str | Path) -> bool:
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates the path given for an image or thumbnail."""
|
"""Validates the path given for an image or thumbnail."""
|
||||||
try:
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
os.stat(path)
|
return path.exists()
|
||||||
return True
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __get_cache(self, image_name: str) -> PILImageType | None:
|
def __validate_storage_folders(self) -> None:
|
||||||
|
"""Checks if the required output folders exist and create them if they don't"""
|
||||||
|
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
|
||||||
|
for folder in folders:
|
||||||
|
folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def __get_cache(self, image_name: Path) -> PILImageType | None:
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
def __set_cache(self, image_name: str, image: PILImageType):
|
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||||
if not image_name in self.__cache:
|
if not image_name in self.__cache:
|
||||||
self.__cache[image_name] = image
|
self.__cache[image_name] = image
|
||||||
self.__cache_ids.put(
|
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||||
image_name
|
|
||||||
) # TODO: this should refresh position for LRU cache
|
|
||||||
if len(self.__cache) > self.__max_cache_size:
|
if len(self.__cache) > self.__max_cache_size:
|
||||||
cache_id = self.__cache_ids.get()
|
cache_id = self.__cache_ids.get()
|
||||||
if cache_id in self.__cache:
|
if cache_id in self.__cache:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
class DiskLatentsStorage(LatentsStorageBase):
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
"""Stores latents in a folder on disk without caching"""
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
__output_folder: str
|
__output_folder: str | Path
|
||||||
|
|
||||||
def __init__(self, output_folder: str):
|
def __init__(self, output_folder: str | Path):
|
||||||
self.__output_folder = output_folder
|
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
def get(self, name: str) -> torch.Tensor:
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
return torch.load(latent_path)
|
return torch.load(latent_path)
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
torch.save(data, latent_path)
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
def delete(self, name: str) -> None:
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
os.remove(latent_path)
|
latent_path.unlink()
|
||||||
|
|
||||||
def get_path(self, name: str) -> str:
|
|
||||||
return os.path.join(self.__output_folder, name)
|
def get_path(self, name: str) -> Path:
|
||||||
|
return self.__output_folder / name
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user