diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index f7ef2e049d..4bfaa6ce7d 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -25,6 +25,7 @@ from abc import ABC, abstractmethod from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union +from invokeai.app.services.config import InvokeAIAppConfig from pydantic import BaseModel, Field @@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase): """ models_found: Set[Path] = Field(default_factory=set) - scanned_dirs: Set[Path] = Field(default_factory=set) - pruned_paths: Set[Path] = Field(default_factory=set) + config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() def search_started(self) -> None: self.models_found = set() - self.scanned_dirs = set() - self.pruned_paths = set() if self.on_search_started: self.on_search_started(self._directory) @@ -139,53 +137,51 @@ class ModelSearch(ModelSearchBase): def search(self, directory: Union[Path, str]) -> Set[Path]: self._directory = Path(directory) + if not self._directory.is_absolute(): + self._directory = self.config.models_path / self._directory self.stats = SearchStats() # zero out self.search_started() # This will initialize _models_found to empty self._walk_directory(directory) self.search_completed() return self.models_found - def _walk_directory(self, path: Union[Path, str]) -> None: - for root, dirs, files in os.walk(path, followlinks=True): - # don't descend into directories that start with a "." - # to avoid the Mac .DS_STORE issue. - if str(Path(root).name).startswith("."): - self.pruned_paths.add(Path(root)) - if any(Path(root).is_relative_to(x) for x in self.pruned_paths): - continue + def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None: + absolute_path = Path(path) + if len(absolute_path.parts) - len(self._directory.parts) > max_depth \ + or not absolute_path.exists() \ + or absolute_path.parent in self.models_found: + return + entries = os.scandir(absolute_path) + entries = [entry for entry in entries if not entry.name.startswith(".")] + dirs = [entry for entry in entries if entry.is_dir()] + file_names = [entry.name for entry in entries if entry.is_file()] + if any( + x in file_names + for x in [ + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "image_encoder.txt", + ] + ): + try: + self.model_found(absolute_path) + except KeyboardInterrupt: + raise + except Exception as e: + self.logger.warning(str(e)) + finally: + return - self.stats.items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path.parent in self.scanned_dirs: - self.scanned_dirs.add(path) - continue - if any( - (path / x).exists() - for x in [ - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - ] - ): - self.scanned_dirs.add(path) - try: - self.model_found(path) - except KeyboardInterrupt: - raise - except Exception as e: - self.logger.warning(str(e)) + for n in file_names: + if any([n.endswith(suffix) for suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}]): + try: + self.model_found(absolute_path / n) + except KeyboardInterrupt: + raise + except Exception as e: + self.logger.warning(str(e)) - for f in files: - path = Path(root) / f - if path.parent in self.scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.model_found(path) - except KeyboardInterrupt: - raise - except Exception as e: - self.logger.warning(str(e)) + for d in dirs: + self._walk_directory(absolute_path / d)