from pathlib import Path

import pytest

from invokeai.backend.model_manager.search import ModelSearch


def _touch(*paths: Path) -> None:
    """Create each of *paths* as an empty file.

    Parent directories must already exist; ``Path.touch`` raises
    ``FileNotFoundError`` otherwise.
    """
    for path in paths:
        path.touch()


@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Provide a fresh ``ModelSearch`` and the per-test temporary directory to search."""
    search = ModelSearch()
    return search, tmp_path


def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """``on_search_started`` is called with the root path passed to ``search()``."""
    search, tmp_path = model_search
    on_search_started_called_with: Path | None = None

    def on_search_started_callback(path: Path) -> None:
        nonlocal on_search_started_called_with
        on_search_started_called_with = path

    search.on_search_started = on_search_started_callback
    search.search(tmp_path)
    assert on_search_started_called_with == tmp_path


def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """``on_search_completed`` receives the same set of models that ``search()`` returns."""
    search, tmp_path = model_search
    on_search_completed_called_with: set[Path] | None = None

    file1 = tmp_path / "file1.ckpt"
    _touch(file1)

    def on_search_completed_callback(models: set[Path]) -> None:
        nonlocal on_search_completed_called_with
        on_search_completed_called_with = models

    search.on_search_completed = on_search_completed_callback

    expected = {file1}
    found = search.search(tmp_path)
    assert found == expected
    assert on_search_completed_called_with == expected


def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """Model files are found at any nesting depth; the ``.txt`` file is not reported."""
    search, tmp_path = model_search
    on_model_found_called_with: set[Path] = set()

    file1 = tmp_path / "file1.ckpt"
    file2 = tmp_path / "file2.ckpt"
    file3 = tmp_path / "subfolder" / "file3.ckpt"
    file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt"
    file5 = tmp_path / "not_a_model_file.txt"
    # Creating file4's parent with parents=True also creates file3's parent ("subfolder").
    file4.parent.mkdir(parents=True)
    _touch(file1, file2, file3, file4, file5)

    def on_model_found_callback(path: Path) -> bool:
        on_model_found_called_with.add(path)
        return True

    search.on_model_found = on_model_found_callback

    expected = {file1, file2, file3, file4}
    found = search.search(tmp_path)
    assert on_model_found_called_with == expected
    assert found == expected
    assert search.stats.models_found == 4
    assert search.stats.models_filtered == 4


def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """A model rejected by ``on_model_found`` (returns False) is excluded from the results.

    The stats are expected to count both candidates in ``models_found`` (2) but only the
    accepted one in ``models_filtered`` (1).
    """
    search, tmp_path = model_search
    on_model_found_called_with: set[Path] = set()

    file1 = tmp_path / "file1.ckpt"
    file2 = tmp_path / "file2.ckpt"  # explicitly ignored
    _touch(file1, file2)

    def on_model_found_callback(path: Path) -> bool:
        if path == file2:
            return False
        on_model_found_called_with.add(path)
        return True

    search.on_model_found = on_model_found_callback

    expected = {file1}
    found = search.search(tmp_path)
    assert on_model_found_called_with == expected
    assert found == expected
    assert search.stats.models_filtered == 1
    assert search.stats.models_found == 2


def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """Directories containing ``model_index.json`` are reported as single models.

    Files inside such a directory (``ignore_me.ckpt``) are not reported separately.
    A directory without ``model_index.json`` is not reported at all.
    """
    search, tmp_path = model_search
    on_model_found_called_with: set[Path] = set()

    diffusers_dir = tmp_path / "diffusers_dir"
    diffusers_dir_entry_point = diffusers_dir / "model_index.json"
    diffusers_dir.mkdir()
    _touch(diffusers_dir_entry_point)

    nested_diffusers_dir = tmp_path / "subfolder" / "nested_diffusers_dir"
    nested_diffusers_dir_entry_point = nested_diffusers_dir / "model_index.json"
    nested_diffusers_dir_ignore_me_file = nested_diffusers_dir / "ignore_me.ckpt"  # totally skipped
    nested_diffusers_dir.mkdir(parents=True)
    _touch(nested_diffusers_dir_entry_point, nested_diffusers_dir_ignore_me_file)

    not_a_diffusers_dir = tmp_path / "not_a_diffusers_dir"
    not_a_diffusers_dir_entry_point = not_a_diffusers_dir / "not_model_index.json"
    not_a_diffusers_dir.mkdir()
    _touch(not_a_diffusers_dir_entry_point)

    def on_model_found_callback(path: Path) -> bool:
        on_model_found_called_with.add(path)
        return True

    search.on_model_found = on_model_found_callback

    expected = {diffusers_dir, nested_diffusers_dir}
    found = search.search(tmp_path)
    assert found == expected
    assert on_model_found_called_with == expected
    assert search.stats.models_found == 2
    assert search.stats.models_filtered == 2