make download and convert cache keys safe for filename length

This commit is contained in:
Lincoln Stein
2024-04-28 12:24:36 -04:00
parent bb04f496e0
commit a26667d3ca
4 changed files with 36 additions and 10 deletions

View File

@ -7,6 +7,7 @@ from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:

View File

@ -19,7 +19,6 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
# TO DO: The loader is not thread safe!
@ -85,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(slugify(model_path))
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:

View File

@ -18,7 +18,8 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
underscores, or hyphens. Replace slashes with underscores.
Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
@ -29,10 +30,17 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[^\w\s-]", "", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")
def safe_filename(directory: Path, value: str) -> str:
"""Make a string safe to use as a filename."""
escaped_string = slugify(value)
max_name_length = os.pathconf(directory, "PC_NAME_MAX")
return escaped_string[len(escaped_string) - max_name_length :]
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).