# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Recursive directory search for models.

Example usage:
```
from pathlib import Path

from invokeai.backend.model_manager import ModelSearch, ModelProbe

def find_main_models(model: Path) -> bool:
    info = ModelProbe.probe(model)
    return info.model_type == 'main' and info.base_type == 'sd-1'

search = ModelSearch(on_model_found=find_main_models)
found = search.search(Path('/tmp/models'))
print(found)         # set of matching model paths
print(search.stats)  # search stats
```
"""

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

from invokeai.backend.util.logging import InvokeAILogger


@dataclass
class SearchStats:
    """Statistics about the search.

    Attributes:
        items_scanned: number of directory entries scanned
        models_found: number of models found
        models_filtered: number of models that passed the filter
    """

    items_scanned: int = 0
    models_found: int = 0
    models_filtered: int = 0


class ModelSearch:
    """Searches a directory tree for models, using a callback to filter the results.

    Usage:
        search = ModelSearch()
        search.on_model_found = lambda path: 'anime' in path.as_posix()
        found = search.search(Path('/tmp/models1'))
    """

    def __init__(
        self,
        on_search_started: Optional[Callable[[Path], None]] = None,
        on_model_found: Optional[Callable[[Path], bool]] = None,
        on_search_completed: Optional[Callable[[set[Path]], None]] = None,
    ) -> None:
        """Create a new ModelSearch object.

        Args:
            on_search_started: callback to be invoked when the search starts
            on_model_found: callback to be invoked when a model is found. The callback
                should return True if the model should be included in the results.
            on_search_completed: callback to be invoked when the search is completed
        """
        self.stats = SearchStats()
        self.logger = InvokeAILogger.get_logger()
        self.on_search_started = on_search_started
        self.on_model_found = on_model_found
        self.on_search_completed = on_search_completed
        self.models_found: set[Path] = set()

    def search_started(self) -> None:
        """Reset the result set and invoke the on_search_started callback, if any."""
        self.models_found = set()
        if self.on_search_started:
            self.on_search_started(self._directory)

    def model_found(self, model: Path) -> None:
        """Record a candidate model, keeping it only if the filter callback accepts it."""
        self.stats.models_found += 1
        if self.on_model_found is None or self.on_model_found(model):
            self.stats.models_filtered += 1
            self.models_found.add(model)

    def search_completed(self) -> None:
        """Invoke the on_search_completed callback, if any, with the final result set."""
        if self.on_search_completed is not None:
            self.on_search_completed(self.models_found)

    def search(self, directory: Path) -> set[Path]:
        """Walk the tree rooted at `directory` and return the set of model paths found."""
        self._directory = Path(directory).resolve()
        self.stats = SearchStats()  # zero out
        self.search_started()  # This will initialize models_found to empty
        self._walk_directory(self._directory)
        self.search_completed()
        return self.models_found

    def _walk_directory(self, path: Path, max_depth: int = 20) -> None:
        """Recursively walk the directory tree, looking for models."""
        absolute_path = Path(path)
        if (
            len(absolute_path.parts) - len(self._directory.parts) > max_depth
            or not absolute_path.exists()
            or absolute_path.parent in self.models_found
        ):
            return

        entries = os.scandir(absolute_path.as_posix())
        entries = [entry for entry in entries if not entry.name.startswith(".")]
        self.stats.items_scanned += len(entries)
        dirs = [entry for entry in entries if entry.is_dir()]
        file_names = [entry.name for entry in entries if entry.is_file()]

        # A directory containing any of these marker files is reported as a single
        # model folder (e.g. diffusers format) rather than descended into.
        if any(
            x in file_names
            for x in [
                "config.json",
                "model_index.json",
                "learned_embeds.bin",
                "pytorch_lora_weights.bin",
                "image_encoder.txt",
            ]
        ):
            try:
                self.model_found(absolute_path)
                return
            except KeyboardInterrupt:
                raise
            except Exception as e:
                self.logger.warning(str(e))
                return

        # Loose checkpoint/weights files are reported individually.
        for n in file_names:
            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
                try:
                    self.model_found(absolute_path / n)
                except KeyboardInterrupt:
                    raise
                except Exception as e:
                    self.logger.warning(str(e))

        for d in dirs:
            self._walk_directory(absolute_path / d.name)
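

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API):
    # scans a directory given on the command line, falling back to the example
    # path from the module docstring, and keeps only .safetensors files as a
    # sample filter callback.
    import sys

    root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("/tmp/models")
    search = ModelSearch(on_model_found=lambda p: p.suffix == ".safetensors")
    found = search.search(root)
    for model_path in sorted(found):
        print(model_path)
    print(search.stats)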