mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): improved model hash class
- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib) - Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block) - Add support for custom file filters - Update docstrings - Update tests
This commit is contained in:
@ -6,7 +6,7 @@ from typing import Iterable
|
||||
import pytest
|
||||
from blake3 import blake3
|
||||
|
||||
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
||||
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
|
||||
|
||||
test_cases: list[tuple[ALGORITHM, str]] = [
|
||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||
@ -57,24 +57,40 @@ def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
|
||||
|
||||
|
||||
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
||||
model_files = {
|
||||
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
|
||||
}
|
||||
model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
|
||||
|
||||
for i, f in enumerate(model_files):
|
||||
f.write_text(f"data{i}")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
# Add file that should be ignored - hash should not change
|
||||
file = tmp_path / "test.icecream"
|
||||
file.write_text("data")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
# Add file that should not be ignored - hash should change
|
||||
file = tmp_path / "test.bin"
|
||||
file.write_text("more data")
|
||||
model_files.add(file)
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||
model_files
|
||||
)
|
||||
|
||||
|
||||
def test_model_hash_uses_custom_filter(tmp_path: Path):
|
||||
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
|
||||
|
||||
for i, f in enumerate(model_files):
|
||||
f.write_text(f"data{i}")
|
||||
|
||||
def file_filter(file_path: str) -> bool:
|
||||
return file_path.endswith(".pickme")
|
||||
|
||||
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}
|
||||
|
Reference in New Issue
Block a user