Switch ModelSearch from os.walk to os.scandir

This commit is contained in:
Brandon Rising 2024-02-22 20:36:42 -05:00 committed by Brandon
parent 8c6860a2c5
commit a8c3efd98a

View File

@ -25,6 +25,7 @@ from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Set, Union from typing import Callable, Optional, Set, Union
from invokeai.app.services.config import InvokeAIAppConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
""" """
models_found: Set[Path] = Field(default_factory=set) models_found: Set[Path] = Field(default_factory=set)
scanned_dirs: Set[Path] = Field(default_factory=set) config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
pruned_paths: Set[Path] = Field(default_factory=set)
def search_started(self) -> None: def search_started(self) -> None:
self.models_found = set() self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started: if self.on_search_started:
self.on_search_started(self._directory) self.on_search_started(self._directory)
@ -139,53 +137,51 @@ class ModelSearch(ModelSearchBase):
def search(self, directory: Union[Path, str]) -> Set[Path]: def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory) self._directory = Path(directory)
if not self._directory.is_absolute():
self._directory = self.config.models_path / self._directory
self.stats = SearchStats() # zero out self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory) self._walk_directory(directory)
self.search_completed() self.search_completed()
return self.models_found return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None: def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
for root, dirs, files in os.walk(path, followlinks=True): absolute_path = Path(path)
# don't descend into directories that start with a "." if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
# to avoid the Mac .DS_STORE issue. or not absolute_path.exists() \
if str(Path(root).name).startswith("."): or absolute_path.parent in self.models_found:
self.pruned_paths.add(Path(root)) return
if any(Path(root).is_relative_to(x) for x in self.pruned_paths): entries = os.scandir(absolute_path)
continue entries = [entry for entry in entries if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_names = [entry.name for entry in entries if entry.is_file()]
if any(
x in file_names
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
try:
self.model_found(absolute_path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
finally:
return
self.stats.items_scanned += len(dirs) + len(files) for n in file_names:
for d in dirs: if any([n.endswith(suffix) for suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}]):
path = Path(root) / d try:
if path.parent in self.scanned_dirs: self.model_found(absolute_path / n)
self.scanned_dirs.add(path) except KeyboardInterrupt:
continue raise
if any( except Exception as e:
(path / x).exists() self.logger.warning(str(e))
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files: for d in dirs:
path = Path(root) / f self._walk_directory(absolute_path / d)
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))