From bd5b43c00d44feb5c1c3ee7bd3286e76ee5a28a6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Mar 2024 20:55:06 +1100 Subject: [PATCH] tidy(mm): ModelSearch cleanup - No need for it to by a pydantic model. Just a class now. - Remove ABC, it made it hard to understand what was going on as attributes were spread across the ABC and implementation. Also, there is no other implementation. - Add tests --- invokeai/backend/model_manager/search.py | 125 +++++++------------- tests/test_model_search.py | 142 +++++++++++++++++++++++ 2 files changed, 187 insertions(+), 80 deletions(-) create mode 100644 tests/test_model_search.py diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 77eda3fddc..86fbbda6ca 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -21,104 +21,69 @@ Example usage: """ import os -from abc import ABC, abstractmethod +from dataclasses import dataclass from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union -from pydantic import BaseModel, Field - from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger default_logger: Logger = InvokeAILogger.get_logger() -class SearchStats(BaseModel): - items_scanned: int = 0 - models_found: int = 0 - models_filtered: int = 0 +@dataclass +class SearchStats: + """Statistics about the search. - -class ModelSearchBase(ABC, BaseModel): + Attributes: + items_scanned: number of items scanned + models_found: number of models found + models_filtered: number of models that passed the filter """ - Abstract directory traversal model search class + + items_scanned = 0 + models_found = 0 + models_filtered = 0 + + +class ModelSearch: + """Searches a directory tree for models, using a callback to filter the results. Usage: - search = ModelSearchBase( - on_search_started = search_started_callback, - on_search_completed = search_completed_callback, - on_model_found = model_found_callback, - ) - models_found = search.search('/path/to/directory') + search = ModelSearch() + search.model_found = lambda path : 'anime' in path.as_posix() + found = search.list_models(['/tmp/models1','/tmp/models2']) + # returns all models that have 'anime' in the path """ - # fmt: off - on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221 - on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 - on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 - stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 - logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 - # fmt: on + def __init__( + self, + stats: Optional[SearchStats] = None, + logger: Optional[Logger] = None, + on_search_started: Optional[Callable[[Path], None]] = None, + on_model_found: Optional[Callable[[Path], bool]] = None, + on_search_completed: Optional[Callable[[Set[Path]], None]] = None, + config: Optional[InvokeAIAppConfig] = None, + ) -> None: + """Create a new ModelSearch object. - class Config: - arbitrary_types_allowed = True - - @abstractmethod - def search_started(self) -> None: + Args: + stats: SearchStats object to hold statistics about the search + logger: Logger object to use for logging + on_search_started: callback to be invoked when the search starts + on_model_found: callback to be invoked when a model is found. The callback should return True if the model + should be included in the results. + on_search_completed: callback to be invoked when the search is completed + config: configuration object """ - Called before the scan starts. - - Passes the root search directory to the Callable `on_search_started`. - """ - pass - - @abstractmethod - def model_found(self, model: Path) -> None: - """ - Called when a model is found during search. - - :param model: Model to process - could be a directory or checkpoint. - - Passes the model's Path to the Callable `on_model_found`. - This Callable receives the path to the model and returns a boolean - to indicate whether the model should be returned in the search - results. - """ - pass - - @abstractmethod - def search_completed(self) -> None: - """ - Called before the scan starts. - - Passes the Set of found model Paths to the Callable `on_search_completed`. - """ - pass - - @abstractmethod - def search(self, directory: Union[Path, str]) -> Set[Path]: - """ - Recursively search for models in `directory` and return a set of model paths. - - If provided, the `on_search_started`, `on_model_found` and `on_search_completed` - Callables will be invoked during the search. - """ - pass - - -class ModelSearch(ModelSearchBase): - """ - Implementation of ModelSearch with callbacks. - Usage: - search = ModelSearch() - search.model_found = lambda path : 'anime' in path.as_posix() - found = search.list_models(['/tmp/models1','/tmp/models2']) - # returns all models that have 'anime' in the path - """ - - models_found: Set[Path] = Field(default_factory=set) - config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() + self.stats = stats or SearchStats() + self.logger = logger or default_logger + self.on_search_started = on_search_started + self.on_model_found = on_model_found + self.on_search_completed = on_search_completed + self.models_found: set[Path] = set() + self.config = config or InvokeAIAppConfig.get_config() def search_started(self) -> None: self.models_found = set() diff --git a/tests/test_model_search.py b/tests/test_model_search.py new file mode 100644 index 0000000000..a6470c6620 --- /dev/null +++ b/tests/test_model_search.py @@ -0,0 +1,142 @@ +from pathlib import Path + +import pytest + +from invokeai.backend.model_manager.search import ModelSearch + + +@pytest.fixture +def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]: + search = ModelSearch() + return search, tmp_path + + +def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]): + search, tmp_path = model_search + on_search_started_called_with: Path | None = None + + def on_search_started_callback(path: Path) -> None: + nonlocal on_search_started_called_with + on_search_started_called_with = path + + search.on_search_started = on_search_started_callback + search.search(tmp_path) + + assert on_search_started_called_with == tmp_path + + +def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]): + search, tmp_path = model_search + on_search_completed_called_with: set[Path] | None = None + file1 = tmp_path / "file1.ckpt" + with open(file1, "w") as f: + f.write("") + + def on_search_completed_callback(models: set[Path]) -> None: + nonlocal on_search_completed_called_with + on_search_completed_called_with = models + + search.on_search_completed = on_search_completed_callback + expected = {file1} + found = search.search(tmp_path) + + assert found == expected + assert on_search_completed_called_with == expected + + +def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]): + search, tmp_path = model_search + on_model_found_called_with: set[Path] = set() + + file1 = tmp_path / "file1.ckpt" + file2 = tmp_path / "file2.ckpt" + file3 = tmp_path / "subfolder" / "file3.ckpt" + file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt" + file5 = tmp_path / "not_a_model_file.txt" + + file4.parent.mkdir(parents=True) + for file in [file1, file2, file3, file4, file5]: + with open(file, "w") as f: + f.write("") + + def on_model_found_callback(path: Path) -> bool: + on_model_found_called_with.add(path) + return True + + search.on_model_found = on_model_found_callback + + expected = {file1, file2, file3, file4} + found = search.search(tmp_path) + + assert on_model_found_called_with == expected + assert found == expected + assert search.stats.models_found == 4 + assert search.stats.models_filtered == 4 + + +def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]): + search, tmp_path = model_search + on_model_found_called_with: set[Path] = set() + + file1 = tmp_path / "file1.ckpt" + file2 = tmp_path / "file2.ckpt" # explicitly ignored + + for file in [file1, file2]: + with open(file, "w") as f: + f.write("") + + def on_model_found_callback(path: Path) -> bool: + if path == file2: + return False + on_model_found_called_with.add(path) + return True + + search.on_model_found = on_model_found_callback + + expected = {file1} + found = search.search(tmp_path) + + assert on_model_found_called_with == expected + assert found == expected + assert search.stats.models_filtered == 1 + assert search.stats.models_found == 2 + + +def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]): + search, tmp_path = model_search + on_model_found_called_with: set[Path] = set() + + diffusers_dir = tmp_path / "diffusers_dir" + diffusers_dir_entry_point = diffusers_dir / "model_index.json" + diffusers_dir.mkdir() + with open(diffusers_dir_entry_point, "w") as f: + f.write("") + + nested_diffusers_dir = tmp_path / "subfolder" / "nested_diffusers_dir" + nested_diffusers_dir_entry_point = nested_diffusers_dir / "model_index.json" + nested_diffusers_dir_ignore_me_file = nested_diffusers_dir / "ignore_me.ckpt" # totally skipped + nested_diffusers_dir.mkdir(parents=True) + with open(nested_diffusers_dir_entry_point, "w") as f: + f.write("") + with open(nested_diffusers_dir_ignore_me_file, "w") as f: + f.write("") + + not_a_diffusers_dir = tmp_path / "not_a_diffusers_dir" + not_a_diffusers_dir_entry_point = not_a_diffusers_dir / "not_model_index.json" + not_a_diffusers_dir.mkdir() + with open(not_a_diffusers_dir_entry_point, "w") as f: + f.write("") + + def on_model_found_callback(path: Path) -> bool: + on_model_found_called_with.add(path) + return True + + search.on_model_found = on_model_found_callback + + expected = {diffusers_dir, nested_diffusers_dir} + found = search.search(tmp_path) + + assert found == expected + assert on_model_found_called_with == expected + assert search.stats.models_found == 2 + assert search.stats.models_filtered == 2