# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.

Example usage:
```
    from invokeai.backend.model_manager import ModelSearch, ModelProbe

    def find_main_models(model: Path) -> bool:
        info = ModelProbe.probe(model)
        if info.model_type == 'main' and info.base_type == 'sd-1':
            return True
        else:
            return False

    search = ModelSearch(on_model_found=find_main_models)
    found = search.search('/tmp/models')
    print(found)   # set of matching model paths
    print(search.stats)  # search stats
```
"""

import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union

from pydantic import BaseModel, Field, PrivateAttr

from invokeai.backend.util.logging import InvokeAILogger

default_logger = InvokeAILogger.get_logger()

# Files whose presence marks a directory as a folder-format model.
_MODEL_DIR_MARKERS = (
    "config.json",
    "model_index.json",
    "learned_embeds.bin",
    "pytorch_lora_weights.bin",
    "image_encoder.txt",
)

# File suffixes treated as single-file model checkpoints.
_MODEL_FILE_SUFFIXES = frozenset({".ckpt", ".bin", ".pth", ".safetensors", ".pt"})


class SearchStats(BaseModel):
    """Counters collected during a single search."""

    items_scanned: int = 0  # directory entries (files + subdirectories) examined
    models_found: int = 0  # model candidates handed to `model_found`
    models_filtered: int = 0  # candidates accepted by the `on_model_found` filter


class ModelSearchBase(ABC, BaseModel):
    """
    Abstract directory traversal model search class

    Usage (via a concrete subclass such as ModelSearch):
       search = ModelSearch(
           on_search_started = search_started_callback,
           on_search_completed = search_completed_callback,
           on_model_found = model_found_callback,
       )
       models_found = search.search('/path/to/directory')
    """

    # fmt: off
    on_search_started   : Optional[Callable[[Path], None]]      = Field(default=None, description="Called just before the search starts.")  # noqa E221
    on_model_found      : Optional[Callable[[Path], bool]]      = Field(default=None, description="Called when a model is found.")  # noqa E221
    on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.")  # noqa E221
    stats               : SearchStats                           = Field(default_factory=SearchStats, description="Summary statistics after search")  # noqa E221
    logger              : InvokeAILogger                        = Field(default=default_logger, description="Logger instance.")  # noqa E221
    # fmt: on

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def search_started(self) -> None:
        """
        Called before the scan starts.

        Passes the root search directory to the Callable `on_search_started`.
        """
        pass

    @abstractmethod
    def model_found(self, model: Path) -> None:
        """
        Called when a model is found during search.

        :param model: Model to process - could be a directory or checkpoint.

        Passes the model's Path to the Callable `on_model_found`.
        This Callable receives the path to the model and returns a boolean
        to indicate whether the model should be returned in the search
        results.
        """
        pass

    @abstractmethod
    def search_completed(self) -> None:
        """
        Called after the scan completes.

        Passes the Set of found model Paths to the Callable `on_search_completed`.
        """
        pass

    @abstractmethod
    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search for models in `directory` and return a set of model paths.

        If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
        Callables will be invoked during the search.
        """
        pass


class ModelSearch(ModelSearchBase):
    """
    Implementation of ModelSearch with callbacks, built on `os.walk`.

    Usage:
       search = ModelSearch(on_model_found=lambda path: 'anime' in path.as_posix())
       found = search.search('/tmp/models')
       # returns all models that have 'anime' in the path

    Subdirectories whose names start with "." are not descended into (they are
    still checked as model candidates themselves). Symbolic links are followed,
    and the walk depth is capped so a symlink cycle cannot loop forever.
    """

    models_found: Set[Path] = Field(default_factory=set, description="Models accepted during the last search.")
    scanned_dirs: Set[Path] = Field(
        default_factory=set, description="Model directories and subdirectories encountered beneath them."
    )
    pruned_paths: Set[Path] = Field(default_factory=set, description="Hidden directories that were not descended into.")

    # Root directory of the current search; set by `search()`.
    _directory: Optional[Path] = PrivateAttr(default=None)

    def search_started(self) -> None:
        """Reset per-search state, then invoke `on_search_started` with the root directory."""
        self.models_found = set()
        self.scanned_dirs = set()
        self.pruned_paths = set()
        if self.on_search_started:
            self.on_search_started(self._directory)

    def model_found(self, model: Path) -> None:
        """
        Record `model` as a candidate and keep it if it passes the filter.

        The candidate is kept when no `on_model_found` filter is set, or when the
        filter returns a truthy value for it.
        """
        self.stats.models_found += 1
        if not self.on_model_found or self.on_model_found(model):
            self.stats.models_filtered += 1
            self.models_found.add(model)

    def search_completed(self) -> None:
        """Invoke `on_search_completed` with the set of accepted model paths."""
        if self.on_search_completed:
            self.on_search_completed(self.models_found)

    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search `directory` for models.

        :param directory: Root directory to search.
        :returns: The set of model paths accepted by the `on_model_found` filter.
        """
        self._directory = Path(directory)
        self.stats = SearchStats()  # zero out counters from any previous search
        self.search_started()  # resets models_found, scanned_dirs and pruned_paths
        self._walk_directory(self._directory)
        self.search_completed()
        return self.models_found

    def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
        """
        Walk `path` top-down, reporting model directories and checkpoint files.

        :param path: Directory to walk.
        :param max_depth: Maximum depth (relative to `path`) to descend to. Because
            symlinks are followed and os.walk does not track visited directories,
            this bound is what stops a symlink cycle from recursing forever.
        """
        top = Path(path)
        top_depth = len(top.parts)
        for root, dirs, files in os.walk(top, followlinks=True):
            root_path = Path(root)
            self.stats.items_scanned += len(dirs) + len(files)

            for d in dirs:
                self._scan_directory(root_path / d)
            for f in files:
                self._scan_file(root_path / f)

            # Edit `dirs` in place to control which subdirectories os.walk descends
            # into. Hidden directories (e.g. ".git", ".cache") are skipped; this
            # applies only to subdirectories, so an explicitly requested hidden or
            # ".." root is still searched.
            hidden = [d for d in dirs if d.startswith(".")]
            self.pruned_paths.update(root_path / d for d in hidden)
            if len(root_path.parts) - top_depth >= max_depth:
                if dirs:
                    self.logger.warning(
                        f"Maximum search depth {max_depth} reached at {root_path}; not descending further."
                    )
                dirs.clear()
            else:
                dirs[:] = [d for d in dirs if not d.startswith(".")]

    def _scan_directory(self, path: Path) -> None:
        """Report `path` as a model if it contains a marker file, unless it lies inside a known model."""
        if path.parent in self.scanned_dirs:
            # Inside an already-identified model directory: mark it too, so its
            # own contents are also skipped rather than reported as separate models.
            self.scanned_dirs.add(path)
            return
        if any((path / marker).exists() for marker in _MODEL_DIR_MARKERS):
            self.scanned_dirs.add(path)
            self._report_model(path)

    def _scan_file(self, path: Path) -> None:
        """Report `path` as a checkpoint model if its suffix matches, unless it lies inside a known model."""
        if path.parent in self.scanned_dirs:
            return
        if path.suffix in _MODEL_FILE_SUFFIXES:
            self._report_model(path)

    def _report_model(self, path: Path) -> None:
        """Call `model_found`, logging rather than propagating errors so one bad model does not abort the search."""
        try:
            self.model_found(path)
        except Exception as e:  # KeyboardInterrupt is not an Exception, so it still propagates
            self.logger.warning(str(e))