tidy(mm): ModelSearch cleanup

- No need for it to by a pydantic model. Just a class now.
- Remove ABC, it made it hard to understand what was going on as attributes were spread across the ABC and implementation. Also, there is no other implementation.
- Add tests
This commit is contained in:
psychedelicious 2024-03-09 20:55:06 +11:00
parent 631e789195
commit bd5b43c00d
2 changed files with 187 additions and 80 deletions

View File

@ -21,95 +21,35 @@ Example usage:
""" """
import os import os
from abc import ABC, abstractmethod from dataclasses import dataclass
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Set, Union from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger() default_logger: Logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel): @dataclass
items_scanned: int = 0 class SearchStats:
models_found: int = 0 """Statistics about the search.
models_filtered: int = 0
Attributes:
class ModelSearchBase(ABC, BaseModel): items_scanned: number of items scanned
""" models_found: number of models found
Abstract directory traversal model search class models_filtered: number of models that passed the filter
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
""" """
# fmt: off items_scanned = 0
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221 models_found = 0
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 models_filtered = 0
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
arbitrary_types_allowed = True
@abstractmethod
def search_started(self) -> None:
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path) -> None:
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self) -> None:
"""
Called before the scan starts.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase): class ModelSearch:
""" """Searches a directory tree for models, using a callback to filter the results.
Implementation of ModelSearch with callbacks.
Usage: Usage:
search = ModelSearch() search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix() search.model_found = lambda path : 'anime' in path.as_posix()
@ -117,8 +57,33 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path # returns all models that have 'anime' in the path
""" """
models_found: Set[Path] = Field(default_factory=set) def __init__(
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() self,
stats: Optional[SearchStats] = None,
logger: Optional[Logger] = None,
on_search_started: Optional[Callable[[Path], None]] = None,
on_model_found: Optional[Callable[[Path], bool]] = None,
on_search_completed: Optional[Callable[[Set[Path]], None]] = None,
config: Optional[InvokeAIAppConfig] = None,
) -> None:
"""Create a new ModelSearch object.
Args:
stats: SearchStats object to hold statistics about the search
logger: Logger object to use for logging
on_search_started: callback to be invoked when the search starts
on_model_found: callback to be invoked when a model is found. The callback should return True if the model
should be included in the results.
on_search_completed: callback to be invoked when the search is completed
config: configuration object
"""
self.stats = stats or SearchStats()
self.logger = logger or default_logger
self.on_search_started = on_search_started
self.on_model_found = on_model_found
self.on_search_completed = on_search_completed
self.models_found: set[Path] = set()
self.config = config or InvokeAIAppConfig.get_config()
def search_started(self) -> None: def search_started(self) -> None:
self.models_found = set() self.models_found = set()

142
tests/test_model_search.py Normal file
View File

@ -0,0 +1,142 @@
from pathlib import Path
import pytest
from invokeai.backend.model_manager.search import ModelSearch
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
search = ModelSearch()
return search, tmp_path
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
search, tmp_path = model_search
on_search_started_called_with: Path | None = None
def on_search_started_callback(path: Path) -> None:
nonlocal on_search_started_called_with
on_search_started_called_with = path
search.on_search_started = on_search_started_callback
search.search(tmp_path)
assert on_search_started_called_with == tmp_path
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
search, tmp_path = model_search
on_search_completed_called_with: set[Path] | None = None
file1 = tmp_path / "file1.ckpt"
with open(file1, "w") as f:
f.write("")
def on_search_completed_callback(models: set[Path]) -> None:
nonlocal on_search_completed_called_with
on_search_completed_called_with = models
search.on_search_completed = on_search_completed_callback
expected = {file1}
found = search.search(tmp_path)
assert found == expected
assert on_search_completed_called_with == expected
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
search, tmp_path = model_search
on_model_found_called_with: set[Path] = set()
file1 = tmp_path / "file1.ckpt"
file2 = tmp_path / "file2.ckpt"
file3 = tmp_path / "subfolder" / "file3.ckpt"
file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt"
file5 = tmp_path / "not_a_model_file.txt"
file4.parent.mkdir(parents=True)
for file in [file1, file2, file3, file4, file5]:
with open(file, "w") as f:
f.write("")
def on_model_found_callback(path: Path) -> bool:
on_model_found_called_with.add(path)
return True
search.on_model_found = on_model_found_callback
expected = {file1, file2, file3, file4}
found = search.search(tmp_path)
assert on_model_found_called_with == expected
assert found == expected
assert search.stats.models_found == 4
assert search.stats.models_filtered == 4
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
search, tmp_path = model_search
on_model_found_called_with: set[Path] = set()
file1 = tmp_path / "file1.ckpt"
file2 = tmp_path / "file2.ckpt" # explicitly ignored
for file in [file1, file2]:
with open(file, "w") as f:
f.write("")
def on_model_found_callback(path: Path) -> bool:
if path == file2:
return False
on_model_found_called_with.add(path)
return True
search.on_model_found = on_model_found_callback
expected = {file1}
found = search.search(tmp_path)
assert on_model_found_called_with == expected
assert found == expected
assert search.stats.models_filtered == 1
assert search.stats.models_found == 2
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
search, tmp_path = model_search
on_model_found_called_with: set[Path] = set()
diffusers_dir = tmp_path / "diffusers_dir"
diffusers_dir_entry_point = diffusers_dir / "model_index.json"
diffusers_dir.mkdir()
with open(diffusers_dir_entry_point, "w") as f:
f.write("")
nested_diffusers_dir = tmp_path / "subfolder" / "nested_diffusers_dir"
nested_diffusers_dir_entry_point = nested_diffusers_dir / "model_index.json"
nested_diffusers_dir_ignore_me_file = nested_diffusers_dir / "ignore_me.ckpt" # totally skipped
nested_diffusers_dir.mkdir(parents=True)
with open(nested_diffusers_dir_entry_point, "w") as f:
f.write("")
with open(nested_diffusers_dir_ignore_me_file, "w") as f:
f.write("")
not_a_diffusers_dir = tmp_path / "not_a_diffusers_dir"
not_a_diffusers_dir_entry_point = not_a_diffusers_dir / "not_model_index.json"
not_a_diffusers_dir.mkdir()
with open(not_a_diffusers_dir_entry_point, "w") as f:
f.write("")
def on_model_found_callback(path: Path) -> bool:
on_model_found_called_with.add(path)
return True
search.on_model_found = on_model_found_callback
expected = {diffusers_dir, nested_diffusers_dir}
found = search.search(tmp_path)
assert found == expected
assert on_model_found_called_with == expected
assert search.stats.models_found == 2
assert search.stats.models_filtered == 2