InvokeAI/tests/test_model_search.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

143 lines
4.5 KiB
Python
Raw Permalink Normal View History

from pathlib import Path
import pytest
from invokeai.backend.model_manager.search import ModelSearch
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Provide a fresh ``ModelSearch`` paired with pytest's per-test temporary directory."""
    return ModelSearch(), tmp_path
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """The ``on_search_started`` callback is handed the directory passed to ``search()``."""
    searcher, root = model_search
    started_with: Path | None = None

    def record_start(path: Path) -> None:
        # Remember the most recent path the searcher reported.
        nonlocal started_with
        started_with = path

    searcher.on_search_started = record_start
    searcher.search(root)

    assert started_with == root
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """The ``on_search_completed`` callback is handed the same set that ``search()`` returns."""
    searcher, root = model_search
    completed_with: set[Path] | None = None

    # A single empty checkpoint file is the only thing in the search root.
    checkpoint = root / "file1.ckpt"
    checkpoint.write_text("")

    def record_completion(models: set[Path]) -> None:
        # Remember the most recent set the searcher reported.
        nonlocal completed_with
        completed_with = models

    searcher.on_search_completed = record_completion

    expected = {checkpoint}
    result = searcher.search(root)

    assert result == expected
    assert completed_with == expected
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """``.ckpt`` files at any depth are found; the ``.txt`` file is not reported."""
    searcher, root = model_search
    visited: set[Path] = set()

    # Layout: two checkpoints at the root, one a level down, one two levels down.
    deepest_dir = root / "subfolder" / "subfolder"
    model_files = (
        root / "file1.ckpt",
        root / "file2.ckpt",
        root / "subfolder" / "file3.ckpt",
        deepest_dir / "file4.ckpt",
    )
    non_model_file = root / "not_a_model_file.txt"

    deepest_dir.mkdir(parents=True)
    for empty_file in (*model_files, non_model_file):
        empty_file.write_text("")

    def accept_everything(path: Path) -> bool:
        # Record every path offered to the callback and keep all of them.
        visited.add(path)
        return True

    searcher.on_model_found = accept_everything

    expected = set(model_files)
    result = searcher.search(root)

    assert visited == expected
    assert result == expected
    assert searcher.stats.models_found == 4
    assert searcher.stats.models_filtered == 4
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """A path for which ``on_model_found`` returns ``False`` is left out of the results."""
    searcher, root = model_search
    accepted: set[Path] = set()

    kept = root / "file1.ckpt"
    rejected = root / "file2.ckpt"  # explicitly ignored
    for empty_file in (kept, rejected):
        empty_file.write_text("")

    def reject_second_file(path: Path) -> bool:
        # Only paths other than the rejected one are recorded and accepted.
        if path != rejected:
            accepted.add(path)
            return True
        return False

    searcher.on_model_found = reject_second_file

    expected = {kept}
    result = searcher.search(root)

    assert accepted == expected
    assert result == expected
    assert searcher.stats.models_filtered == 1
    assert searcher.stats.models_found == 2
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """Directories holding a ``model_index.json`` are reported as models themselves."""
    searcher, root = model_search
    visited: set[Path] = set()

    # A diffusers-style directory directly under the search root.
    top_level_model = root / "diffusers_dir"
    top_level_model.mkdir()
    (top_level_model / "model_index.json").write_text("")

    # A diffusers-style directory one level down, with an extra checkpoint inside it.
    nested_model = root / "subfolder" / "nested_diffusers_dir"
    nested_model.mkdir(parents=True)
    (nested_model / "model_index.json").write_text("")
    (nested_model / "ignore_me.ckpt").write_text("")  # totally skipped

    # A directory whose JSON file does not have the entry-point name.
    plain_dir = root / "not_a_diffusers_dir"
    plain_dir.mkdir()
    (plain_dir / "not_model_index.json").write_text("")

    def accept_everything(path: Path) -> bool:
        # Record every path offered to the callback and keep all of them.
        visited.add(path)
        return True

    searcher.on_model_found = accept_everything

    expected = {top_level_model, nested_model}
    result = searcher.search(root)

    assert result == expected
    assert visited == expected
    assert searcher.stats.models_found == 2
    assert searcher.stats.models_filtered == 2