mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch ModelSearch from os.walk to os.scandir
This commit is contained in:
parent
8c6860a2c5
commit
a8c3efd98a
@ -25,6 +25,7 @@ from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Set, Union
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default_factory=set)
|
||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
||||
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
self.scanned_dirs = set()
|
||||
self.pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
@ -139,53 +137,51 @@ class ModelSearch(ModelSearchBase):
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
if not self._directory.is_absolute():
|
||||
self._directory = self.config.models_path / self._directory
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self.search_completed()
|
||||
return self.models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self.pruned_paths.add(Path(root))
|
||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
||||
continue
|
||||
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||
absolute_path = Path(path)
|
||||
if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
|
||||
or not absolute_path.exists() \
|
||||
or absolute_path.parent in self.models_found:
|
||||
return
|
||||
entries = os.scandir(absolute_path)
|
||||
entries = [entry for entry in entries if not entry.name.startswith(".")]
|
||||
dirs = [entry for entry in entries if entry.is_dir()]
|
||||
file_names = [entry.name for entry in entries if entry.is_file()]
|
||||
if any(
|
||||
x in file_names
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
try:
|
||||
self.model_found(absolute_path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
finally:
|
||||
return
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self.scanned_dirs:
|
||||
self.scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
self.scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
for n in file_names:
|
||||
if any([n.endswith(suffix) for suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}]):
|
||||
try:
|
||||
self.model_found(absolute_path / n)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self.scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
for d in dirs:
|
||||
self._walk_directory(absolute_path / d)
|
||||
|
Loading…
Reference in New Issue
Block a user