diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 88d333dc20..c139ec4c5f 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -16,3 +16,4 @@ from .config import ( # noqa F401 from .install import ModelInstall # noqa F401 from .probe import ModelProbe, InvalidModelException # noqa F401 from .storage import DuplicateModelException # noqa F401 +from .search import ModelSearch diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 942e65ad06..a83ea5d596 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -254,10 +254,10 @@ class ModelInstall(ModelInstallBase): self.unregister(id) def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 - search = ModelSearch() - search.model_found = self._scan_install if install else self._scan_register + callback = self._scan_install if install else self._scan_register + search = ModelSearch(on_model_found=callback) self._installed = set() - search.search([scan_dir]) + search.search(scan_dir) return list(self._installed) def garbage_collect(self) -> List[str]: # noqa D102 diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 7a7f062570..548a0b72cf 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -362,7 +362,7 @@ class LoRACheckpointProbe(CheckpointProbeBase): elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: - raise InvalidModelException(f"Unknown LoRA type: {self.model}") + raise InvalidModelException(f"Unsupported LoRA type: {self.model}") class TextualInversionCheckpointProbe(CheckpointProbeBase): diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 74bf296077..41e5dfc6e3 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -1,97 +1,108 @@ # Copyright 2023, Lincoln D. Stein and the InvokeAI Team """ -Abstract base class for recursive directory search for models. +Abstract base class and implementation for recursive directory search for models. + +Example usage: +``` + from invokeai.backend.model_manager import ModelSearch, ModelProbe + + def find_main_models(model: Path) -> bool: + info = ModelProbe.probe(model) + if info.model_type == 'main' and info.base_type == 'sd-1': + return True + else: + return False + + search = ModelSearch(on_model_found=report_it) + found = search.search('/tmp/models') + print(found) # list of matching model paths + print(search.stats) # search stats +``` """ import os from abc import ABC, abstractmethod -from typing import List, Set, Optional, Callable, Union, types +from typing import Set, Optional, Callable, Union, types from pathlib import Path -import invokeai.backend.util.logging as logger +from invokeai.backend.util.logging import InvokeAILogger +from pydantic import Field, BaseModel + +default_logger = InvokeAILogger.getLogger() -class ModelSearchBase(ABC): - """Hierarchical directory model search class""" +class SearchStats(BaseModel): + items_scanned: int = 0 + models_found: int = 0 + models_filtered: int = 0 - def __init__(self, logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. - :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() + +class ModelSearchBase(ABC, BaseModel): + """ + Abstract directory traversal model search class + + Usage: + search = ModelSearchBase( + on_search_started = search_started_callback, + on_search_completed = search_completed_callback, + on_model_found = model_found_callback, + ) + models_found = search.search('/path/to/directory') + """ + + # fmt: off + on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221 + on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 + on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 + stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 + logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221 + # fmt: on + + class Config: + underscore_attrs_are_private = True + arbitrary_types_allowed = True @abstractmethod - def on_search_started(self): + def search_started(self): """ Called before the scan starts. + + Passes the root search directory to the Callable `on_search_started`. """ pass @abstractmethod - def on_model_found(self, model: Path): + def model_found(self, model: Path): """ - Process a found model. Raise an exception if something goes wrong. + Called when a model is found during search. + :param model: Model to process - could be a directory or checkpoint. + + Passes the model's Path to the Callable `on_model_found`. + This Callable receives the path to the model and returns a boolean + to indicate whether the model should be returned in the search + results. """ pass @abstractmethod - def on_search_completed(self): + def search_completed(self): """ - Perform some activity when the scan is completed. May use instance - variables, items_scanned and models_found + Called before the scan starts. + + Passes the Set of found model Paths to the Callable `on_search_completed`. """ pass - def search(self, directories: List[Union[Path, str]]): - self.on_search_started() - for dir in directories: - self.walk_directory(dir) - self.on_search_completed() + @abstractmethod + def search(self, directory: Union[Path, str]) -> Set[Path]: + """ + Recursively search for models in `directory` and return a set of model paths. - def walk_directory(self, path: Union[Path, str]): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - [ - (path / x).exists() - for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"] - ] - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(str(e)) - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(str(e)) + If provided, the `on_search_started`, `on_model_found` and `on_search_completed` + Callables will be invoked during the search. + """ + pass class ModelSearch(ModelSearchBase): @@ -104,35 +115,73 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - _model_set: Set[Path] - search_started: Callable[[Path], None] - search_completed: Callable[[Set[Path]], None] - model_found: Callable[[Path], bool] + _directory: Path = Field(default=None) + _models_found: Set[Path] = Field(default=None) + _scanned_dirs: Set[Path] = Field(default=None) + _pruned_paths: Set[Path] = Field(default=None) - def __init__(self, logger: types.ModuleType = logger): - super().__init__(logger) - self._model_set = set() - self.search_started = None - self.search_completed = None - self.model_found = None + def search_started(self): + self._models_found = set() + self._scanned_dirs = set() + self._pruned_paths = set() + if self.on_search_started: + self.on_search_started(self._directory) - def on_search_started(self): - self._model_set = set() - if self.search_started: - self.search_started() - - def on_model_found(self, model: Path): - if not self.model_found: - self._model_set.add(model) + def model_found(self, model: Path): + self.stats.models_found += 1 + if not self.on_model_found: + self.stats.models_filtered += 1 + self._models_found.add(model) return - if self.model_found(model): - self._model_set.add(model) + if self.on_model_found(model): + self.stats.models_filtered += 1 + self._models_found.add(model) - def on_search_completed(self): - if self.search_completed: - self.search_completed(self._model_set) + def search_completed(self): + if self.on_search_completed: + self.on_search_completed(self._models_found) - def list_models(self, directories: List[Union[Path, str]]) -> List[Path]: - """Return list of models found""" - self.search(directories) - return list(self._model_set) + def search(self, directory: Union[Path, str]) -> Set[Path]: + self._directory = directory + self.stats = SearchStats() # zero out + self.search_started() # This will initialize _models_found to empty + self._walk_directory(directory) + self.search_completed() + return self._models_found + + def _walk_directory(self, path: Union[Path, str]): + for root, dirs, files in os.walk(path, followlinks=True): + # don't descend into directories that start with a "." + # to avoid the Mac .DS_STORE issue. + if str(Path(root).name).startswith("."): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self.stats.items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any( + [ + (path / x).exists() + for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"] + ] + ): + self._scanned_dirs.add(path) + try: + self.model_found(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: + try: + self.model_found(path) + except Exception as e: + self.logger.warning(str(e))