Switch ModelSearch from os.walk to os.scandir

This commit is contained in:
Brandon Rising 2024-02-22 20:36:42 -05:00 committed by Brandon
parent 8c6860a2c5
commit a8c3efd98a

View File

@ -25,6 +25,7 @@ from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, Set, Union
from invokeai.app.services.config import InvokeAIAppConfig
from pydantic import BaseModel, Field
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
"""
models_found: Set[Path] = Field(default_factory=set)
scanned_dirs: Set[Path] = Field(default_factory=set)
pruned_paths: Set[Path] = Field(default_factory=set)
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
def search_started(self) -> None:
self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
@ -139,53 +137,51 @@ class ModelSearch(ModelSearchBase):
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
if not self._directory.is_absolute():
self._directory = self.config.models_path / self._directory
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None:
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self.pruned_paths.add(Path(root))
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
continue
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
absolute_path = Path(path)
if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
or not absolute_path.exists() \
or absolute_path.parent in self.models_found:
return
entries = os.scandir(absolute_path)
entries = [entry for entry in entries if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_names = [entry.name for entry in entries if entry.is_file()]
if any(
x in file_names
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
try:
self.model_found(absolute_path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
finally:
return
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self.scanned_dirs:
self.scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for n in file_names:
if any([n.endswith(suffix) for suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}]):
try:
self.model_found(absolute_path / n)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for d in dirs:
self._walk_directory(absolute_path / d)