mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): ModelSearch cleanup
- No need for it to by a pydantic model. Just a class now. - Remove ABC, it made it hard to understand what was going on as attributes were spread across the ABC and implementation. Also, there is no other implementation. - Add tests
This commit is contained in:
142
tests/test_model_search.py
Normal file
142
tests/test_model_search.py
Normal file
@ -0,0 +1,142 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
|
||||
search = ModelSearch()
|
||||
return search, tmp_path
|
||||
|
||||
|
||||
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
|
||||
search, tmp_path = model_search
|
||||
on_search_started_called_with: Path | None = None
|
||||
|
||||
def on_search_started_callback(path: Path) -> None:
|
||||
nonlocal on_search_started_called_with
|
||||
on_search_started_called_with = path
|
||||
|
||||
search.on_search_started = on_search_started_callback
|
||||
search.search(tmp_path)
|
||||
|
||||
assert on_search_started_called_with == tmp_path
|
||||
|
||||
|
||||
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
|
||||
search, tmp_path = model_search
|
||||
on_search_completed_called_with: set[Path] | None = None
|
||||
file1 = tmp_path / "file1.ckpt"
|
||||
with open(file1, "w") as f:
|
||||
f.write("")
|
||||
|
||||
def on_search_completed_callback(models: set[Path]) -> None:
|
||||
nonlocal on_search_completed_called_with
|
||||
on_search_completed_called_with = models
|
||||
|
||||
search.on_search_completed = on_search_completed_callback
|
||||
expected = {file1}
|
||||
found = search.search(tmp_path)
|
||||
|
||||
assert found == expected
|
||||
assert on_search_completed_called_with == expected
|
||||
|
||||
|
||||
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
|
||||
search, tmp_path = model_search
|
||||
on_model_found_called_with: set[Path] = set()
|
||||
|
||||
file1 = tmp_path / "file1.ckpt"
|
||||
file2 = tmp_path / "file2.ckpt"
|
||||
file3 = tmp_path / "subfolder" / "file3.ckpt"
|
||||
file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt"
|
||||
file5 = tmp_path / "not_a_model_file.txt"
|
||||
|
||||
file4.parent.mkdir(parents=True)
|
||||
for file in [file1, file2, file3, file4, file5]:
|
||||
with open(file, "w") as f:
|
||||
f.write("")
|
||||
|
||||
def on_model_found_callback(path: Path) -> bool:
|
||||
on_model_found_called_with.add(path)
|
||||
return True
|
||||
|
||||
search.on_model_found = on_model_found_callback
|
||||
|
||||
expected = {file1, file2, file3, file4}
|
||||
found = search.search(tmp_path)
|
||||
|
||||
assert on_model_found_called_with == expected
|
||||
assert found == expected
|
||||
assert search.stats.models_found == 4
|
||||
assert search.stats.models_filtered == 4
|
||||
|
||||
|
||||
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
|
||||
search, tmp_path = model_search
|
||||
on_model_found_called_with: set[Path] = set()
|
||||
|
||||
file1 = tmp_path / "file1.ckpt"
|
||||
file2 = tmp_path / "file2.ckpt" # explicitly ignored
|
||||
|
||||
for file in [file1, file2]:
|
||||
with open(file, "w") as f:
|
||||
f.write("")
|
||||
|
||||
def on_model_found_callback(path: Path) -> bool:
|
||||
if path == file2:
|
||||
return False
|
||||
on_model_found_called_with.add(path)
|
||||
return True
|
||||
|
||||
search.on_model_found = on_model_found_callback
|
||||
|
||||
expected = {file1}
|
||||
found = search.search(tmp_path)
|
||||
|
||||
assert on_model_found_called_with == expected
|
||||
assert found == expected
|
||||
assert search.stats.models_filtered == 1
|
||||
assert search.stats.models_found == 2
|
||||
|
||||
|
||||
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
|
||||
search, tmp_path = model_search
|
||||
on_model_found_called_with: set[Path] = set()
|
||||
|
||||
diffusers_dir = tmp_path / "diffusers_dir"
|
||||
diffusers_dir_entry_point = diffusers_dir / "model_index.json"
|
||||
diffusers_dir.mkdir()
|
||||
with open(diffusers_dir_entry_point, "w") as f:
|
||||
f.write("")
|
||||
|
||||
nested_diffusers_dir = tmp_path / "subfolder" / "nested_diffusers_dir"
|
||||
nested_diffusers_dir_entry_point = nested_diffusers_dir / "model_index.json"
|
||||
nested_diffusers_dir_ignore_me_file = nested_diffusers_dir / "ignore_me.ckpt" # totally skipped
|
||||
nested_diffusers_dir.mkdir(parents=True)
|
||||
with open(nested_diffusers_dir_entry_point, "w") as f:
|
||||
f.write("")
|
||||
with open(nested_diffusers_dir_ignore_me_file, "w") as f:
|
||||
f.write("")
|
||||
|
||||
not_a_diffusers_dir = tmp_path / "not_a_diffusers_dir"
|
||||
not_a_diffusers_dir_entry_point = not_a_diffusers_dir / "not_model_index.json"
|
||||
not_a_diffusers_dir.mkdir()
|
||||
with open(not_a_diffusers_dir_entry_point, "w") as f:
|
||||
f.write("")
|
||||
|
||||
def on_model_found_callback(path: Path) -> bool:
|
||||
on_model_found_called_with.add(path)
|
||||
return True
|
||||
|
||||
search.on_model_found = on_model_found_callback
|
||||
|
||||
expected = {diffusers_dir, nested_diffusers_dir}
|
||||
found = search.search(tmp_path)
|
||||
|
||||
assert found == expected
|
||||
assert on_model_found_called_with == expected
|
||||
assert search.stats.models_found == 2
|
||||
assert search.stats.models_filtered == 2
|
Reference in New Issue
Block a user