mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
143 lines
4.5 KiB
Python
143 lines
4.5 KiB
Python
|
from pathlib import Path
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from invokeai.backend.model_manager.search import ModelSearch
|
||
|
|
||
|
|
||
|
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Provide a freshly constructed ``ModelSearch`` together with pytest's ``tmp_path``.

    The pair is returned as a tuple so each test can unpack both the searcher
    and the directory it should populate and scan.
    """
    return ModelSearch(), tmp_path
|
||
|
|
||
|
|
||
|
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """Check that ``on_search_started`` is invoked with the path passed to ``search()``."""
    searcher, root = model_search
    # Mutable holder: lets the callback record its argument without ``nonlocal``.
    captured: dict[str, Path | None] = {"path": None}

    def record_start(path: Path) -> None:
        captured["path"] = path

    searcher.on_search_started = record_start
    searcher.search(root)

    assert captured["path"] == root
|
||
|
|
||
|
|
||
|
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """Check that ``on_search_completed`` receives the same set that ``search()`` returns.

    A single empty ``file1.ckpt`` is placed in the temporary directory; both the
    return value of ``search()`` and the callback argument are expected to be the
    set containing exactly that path.
    """
    searcher, root = model_search
    # Mutable holder: lets the callback record its argument without ``nonlocal``.
    captured: dict[str, set[Path] | None] = {"models": None}

    checkpoint = root / "file1.ckpt"
    checkpoint.write_text("")

    def record_completion(models: set[Path]) -> None:
        captured["models"] = models

    searcher.on_search_completed = record_completion
    wanted = {checkpoint}
    result = searcher.search(root)

    assert result == wanted
    assert captured["models"] == wanted
|
||
|
|
||
|
|
||
|
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """Check that ``.ckpt`` files at several directory depths are all reported.

    Four empty ``.ckpt`` files (two at the root, one in ``subfolder`` and one in
    ``subfolder/subfolder``) plus one ``.txt`` file are created. The test expects
    the four ``.ckpt`` paths, and only those, both from ``search()`` and via
    ``on_model_found``, with ``stats.models_found`` and ``stats.models_filtered``
    each equal to 4.
    """
    searcher, root = model_search
    seen: set[Path] = set()

    subfolder = root / "subfolder"
    deep_subfolder = subfolder / "subfolder"
    # Creating the deepest directory with parents=True also creates ``subfolder``.
    deep_subfolder.mkdir(parents=True)

    checkpoints = [
        root / "file1.ckpt",
        root / "file2.ckpt",
        subfolder / "file3.ckpt",
        deep_subfolder / "file4.ckpt",
    ]
    non_model = root / "not_a_model_file.txt"
    for empty_file in [*checkpoints, non_model]:
        empty_file.write_text("")

    def record_model(path: Path) -> bool:
        seen.add(path)
        return True

    searcher.on_model_found = record_model

    wanted = set(checkpoints)
    result = searcher.search(root)

    assert seen == wanted
    assert result == wanted
    assert searcher.stats.models_found == 4
    assert searcher.stats.models_filtered == 4
|
||
|
|
||
|
|
||
|
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """Check that a path for which ``on_model_found`` returns ``False`` is left out.

    Two empty ``.ckpt`` files are created; the callback returns ``False`` for
    ``file2.ckpt`` and ``True`` for anything else. The test expects only
    ``file1.ckpt`` in the result, ``stats.models_filtered`` equal to 1 and
    ``stats.models_found`` equal to 2.
    """
    searcher, root = model_search
    accepted: set[Path] = set()

    kept = root / "file1.ckpt"
    rejected = root / "file2.ckpt"  # explicitly ignored by the callback below
    for empty_file in (kept, rejected):
        empty_file.write_text("")

    def accept_unless_rejected(path: Path) -> bool:
        keep = path != rejected
        if keep:
            accepted.add(path)
        return keep

    searcher.on_model_found = accept_unless_rejected

    wanted = {kept}
    result = searcher.search(root)

    assert accepted == wanted
    assert result == wanted
    assert searcher.stats.models_filtered == 1
    assert searcher.stats.models_found == 2
|
||
|
|
||
|
|
||
|
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """Check that directories containing ``model_index.json`` are reported as models.

    Three directories are created:

    * ``diffusers_dir`` with an empty ``model_index.json``;
    * ``subfolder/nested_diffusers_dir`` with an empty ``model_index.json`` and an
      extra ``ignore_me.ckpt``;
    * ``not_a_diffusers_dir`` with only ``not_model_index.json``.

    The test expects exactly the first two directories (the directories
    themselves, not files inside them) from ``search()`` and via
    ``on_model_found``, with ``stats.models_found`` and ``stats.models_filtered``
    each equal to 2.
    """
    searcher, root = model_search
    seen: set[Path] = set()

    diffusers_dir = root / "diffusers_dir"
    diffusers_dir.mkdir()
    (diffusers_dir / "model_index.json").write_text("")

    nested_diffusers_dir = root / "subfolder" / "nested_diffusers_dir"
    nested_diffusers_dir.mkdir(parents=True)
    (nested_diffusers_dir / "model_index.json").write_text("")
    # This checkpoint sits inside a diffusers dir and must not appear in the results.
    (nested_diffusers_dir / "ignore_me.ckpt").write_text("")

    not_a_diffusers_dir = root / "not_a_diffusers_dir"
    not_a_diffusers_dir.mkdir()
    (not_a_diffusers_dir / "not_model_index.json").write_text("")

    def record_model(path: Path) -> bool:
        seen.add(path)
        return True

    searcher.on_model_found = record_model

    wanted = {diffusers_dir, nested_diffusers_dir}
    result = searcher.search(root)

    assert result == wanted
    assert seen == wanted
    assert searcher.stats.models_found == 2
    assert searcher.stats.models_filtered == 2
|