tidy(mm): ModelSearch cleanup

- No need for it to be a pydantic model. Just a class now.
- Remove ABC, it made it hard to understand what was going on as attributes were spread across the ABC and implementation. Also, there is no other implementation.
- Add tests
This commit is contained in:
psychedelicious 2024-03-09 20:55:06 +11:00
parent 631e789195
commit bd5b43c00d
2 changed files with 187 additions and 80 deletions

View File

@ -21,104 +21,69 @@ Example usage:
"""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
@dataclass
class SearchStats:
    """Statistics about the search.

    Attributes:
        items_scanned: number of items scanned
        models_found: number of models found
        models_filtered: number of models that passed the filter
    """

    # The annotations are required: @dataclass only turns *annotated* class
    # attributes into fields. Without them the generated __init__, __repr__
    # and __eq__ ignore the counters, so every instance compares equal.
    items_scanned: int = 0
    models_found: int = 0
    models_filtered: int = 0
class ModelSearch:
"""Searches a directory tree for models, using a callback to filter the results.
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
"""
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
def __init__(
self,
stats: Optional[SearchStats] = None,
logger: Optional[Logger] = None,
on_search_started: Optional[Callable[[Path], None]] = None,
on_model_found: Optional[Callable[[Path], bool]] = None,
on_search_completed: Optional[Callable[[Set[Path]], None]] = None,
config: Optional[InvokeAIAppConfig] = None,
) -> None:
"""Create a new ModelSearch object.
class Config:
arbitrary_types_allowed = True
@abstractmethod
def search_started(self) -> None:
Args:
stats: SearchStats object to hold statistics about the search
logger: Logger object to use for logging
on_search_started: callback to be invoked when the search starts
on_model_found: callback to be invoked when a model is found. The callback should return True if the model
should be included in the results.
on_search_completed: callback to be invoked when the search is completed
config: configuration object
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path) -> None:
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self) -> None:
"""
Called before the scan starts.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearch with callbacks.
Usage:
search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default_factory=set)
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
self.stats = stats or SearchStats()
self.logger = logger or default_logger
self.on_search_started = on_search_started
self.on_model_found = on_model_found
self.on_search_completed = on_search_completed
self.models_found: set[Path] = set()
self.config = config or InvokeAIAppConfig.get_config()
def search_started(self) -> None:
self.models_found = set()

142
tests/test_model_search.py Normal file
View File

@ -0,0 +1,142 @@
from pathlib import Path
import pytest
from invokeai.backend.model_manager.search import ModelSearch
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Pair a default-constructed ModelSearch with the per-test temporary directory."""
    return ModelSearch(), tmp_path
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """`on_search_started` is called with the directory passed to `search()`."""
    searcher, root = model_search
    # A mutable holder lets the callback record its argument without `nonlocal`.
    received: dict[str, Path | None] = {"path": None}

    def record_start(path: Path) -> None:
        received["path"] = path

    searcher.on_search_started = record_start
    searcher.search(root)

    assert received["path"] == root
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """`on_search_completed` receives the same set of paths that `search()` returns."""
    searcher, root = model_search

    checkpoint = root / "file1.ckpt"
    checkpoint.write_text("")  # an empty file is enough for the search to see it

    # A mutable holder lets the callback record its argument without `nonlocal`.
    captured: dict[str, set[Path] | None] = {"models": None}

    def record_completion(models: set[Path]) -> None:
        captured["models"] = models

    searcher.on_search_completed = record_completion

    wanted = {checkpoint}
    result = searcher.search(root)

    assert result == wanted
    assert captured["models"] == wanted
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """`.ckpt` files at any depth are reported; the `.txt` file is not."""
    searcher, root = model_search

    nested = root / "subfolder"
    deeper = nested / "subfolder"
    deeper.mkdir(parents=True)

    checkpoints = {
        root / "file1.ckpt",
        root / "file2.ckpt",
        nested / "file3.ckpt",
        deeper / "file4.ckpt",
    }
    # Create every checkpoint plus one non-model file, all empty.
    for path in (*checkpoints, root / "not_a_model_file.txt"):
        path.write_text("")

    seen: set[Path] = set()

    def accept_all(path: Path) -> bool:
        seen.add(path)
        return True

    searcher.on_model_found = accept_all
    result = searcher.search(root)

    assert seen == checkpoints
    assert result == checkpoints
    assert searcher.stats.models_found == 4
    assert searcher.stats.models_filtered == 4
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """A model rejected by `on_model_found` still counts as found but is left out of the results."""
    searcher, root = model_search

    kept = root / "file1.ckpt"
    rejected = root / "file2.ckpt"  # explicitly ignored
    for path in (kept, rejected):
        path.write_text("")

    accepted: set[Path] = set()

    def reject_second_file(path: Path) -> bool:
        # Only paths that pass the filter are recorded.
        if path != rejected:
            accepted.add(path)
            return True
        return False

    searcher.on_model_found = reject_second_file

    wanted = {kept}
    result = searcher.search(root)

    assert accepted == wanted
    assert result == wanted
    assert searcher.stats.models_filtered == 1
    assert searcher.stats.models_found == 2
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """Directories containing `model_index.json` are reported as models.

    A checkpoint file inside such a directory is not reported separately, and a
    directory without `model_index.json` is not reported at all.
    """
    searcher, root = model_search

    # Diffusers-style model directly under the search root.
    top_level_model = root / "diffusers_dir"
    top_level_model.mkdir()
    (top_level_model / "model_index.json").write_text("")

    # Diffusers-style model one level down, with an extra checkpoint inside it.
    nested_model = root / "subfolder" / "nested_diffusers_dir"
    nested_model.mkdir(parents=True)
    (nested_model / "model_index.json").write_text("")
    (nested_model / "ignore_me.ckpt").write_text("")  # totally skipped

    # Directory lacking the entry-point file.
    plain_dir = root / "not_a_diffusers_dir"
    plain_dir.mkdir()
    (plain_dir / "not_model_index.json").write_text("")

    seen: set[Path] = set()

    def accept_all(path: Path) -> bool:
        seen.add(path)
        return True

    searcher.on_model_found = accept_all

    wanted = {top_level_model, nested_model}
    result = searcher.search(root)

    assert result == wanted
    assert seen == wanted
    assert searcher.stats.models_found == 2
    assert searcher.stats.models_filtered == 2