feat(mm): improved model hash class

- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib)
- Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block)
- Add support for custom file filters
- Update docstrings
- Update tests
This commit is contained in:
psychedelicious 2024-03-03 14:14:15 +11:00 committed by Ryan Dick
parent 93aed57e81
commit af24013bb8
2 changed files with 114 additions and 57 deletions

View File

@ -7,10 +7,11 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f' 'a8e693a126ea5b831c96064dc569956f'
""" """
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union from typing import Callable, Literal, Optional, Union
from blake3 import blake3 from blake3 import blake3
@ -19,7 +20,6 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[ ALGORITHM = Literal[
"md5", "md5",
"sha1", "sha1",
"sha1_fast",
"sha224", "sha224",
"sha256", "sha256",
"sha384", "sha384",
@ -40,7 +40,9 @@ class ModelHash:
""" """
Creates a hash of a model using a specified algorithm. Creates a hash of a model using a specified algorithm.
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3. Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
If the model is a single file, it is hashed directly using the provided algorithm. If the model is a single file, it is hashed directly using the provided algorithm.
@ -51,45 +53,57 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes. that directory hashes are never weaker than the file hashes.
Usage Usage:
```py ```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors") ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/") ModelHash("md5").hash("path/to/model/dir/")
``` ```
""" """
def __init__(self, algorithm: ALGORITHM = "blake3") -> None: def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3": if algorithm == "blake3":
self._hash_file = self._blake3 self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available: elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm) self._hash_file = self._get_hashlib(algorithm)
else: else:
raise ValueError(f"Algorithm {algorithm} not available") raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str: self._file_filter = file_filter or self._default_file_filter
"""
Return hexdigest string for model located at model_location.
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the def hash(self, model_path: Union[str, Path]) -> str:
"""
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3. directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
""" """
model_location = Path(model_location) model_path = Path(model_path)
if model_location.is_file(): if model_path.is_file():
return self._hash_file(model_location) return self._hash_file(model_path)
elif model_location.is_dir(): elif model_path.is_dir():
return self._hash_dir(model_location) return self._hash_dir(model_path)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_path}")
def _hash_dir(self, model_location: Path) -> str: def _hash_dir(self, dir: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest.""" """Compute the hash for all files in a directory and return a hexdigest.
model_component_paths = self._get_file_paths(model_location)
Args:
dir: Path to the directory
Returns:
str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = [] component_hashes: list[str] = []
for component in sorted(model_component_paths): for component in sorted(model_component_paths):
@ -102,43 +116,70 @@ class ModelHash:
composite_hasher.update(h.encode("utf-8")) composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest() return composite_hasher.hexdigest()
@classmethod @staticmethod
def _get_file_paths(cls, dir: Path) -> list[Path]: def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.""" """Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = [] files: list[Path] = []
for root, _dirs, _files in os.walk(dir): for root, _dirs, _files in os.walk(model_path):
for file in _files: for file in _files:
if file.endswith(MODEL_FILE_EXTENSIONS): if file_filter(file):
files.append(Path(root, file)) files.append(Path(root, file))
return files return files
@staticmethod @staticmethod
def _blake3(file_path: Path) -> str: def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3""" """Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO) file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path) file_hasher.update_mmap(file_path)
return file_hasher.hexdigest() return file_hasher.hexdigest()
@staticmethod @staticmethod
def _sha1_fast(file_path: Path) -> str: def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using SHA1, but with a block size of 2**16. """Factory function that returns a function to hash a file with the given algorithm.
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
The algorithm is, however, very fast.""" Args:
BLOCK_SIZE = 2**16 algorithm: Hashing algorithm to use
file_hash = hashlib.sha1()
with open(file_path, "rb") as f: Returns:
data = f.read(BLOCK_SIZE) A function that hashes a file using the given algorithm
file_hash.update(data) """
return file_hash.hexdigest()
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod @staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: def _default_file_filter(file_path: str) -> bool:
"""Hashes a file using a hashlib algorithm""" """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
def hasher(file_path: Path) -> str: Args:
file_hasher = hashlib.new(algorithm) file_path: Path to the file
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()
return hasher Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)

View File

@ -6,7 +6,7 @@ from typing import Iterable
import pytest import pytest
from blake3 import blake3 from blake3 import blake3
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
test_cases: list[tuple[ALGORITHM, str]] = [ test_cases: list[tuple[ALGORITHM, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"), ("md5", "a0cd925fc063f98dbf029eee315060c3"),
@ -57,24 +57,40 @@ def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
def test_model_hash_filters_out_non_model_files(tmp_path: Path): def test_model_hash_filters_out_non_model_files(tmp_path: Path):
model_files = { model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
}
for i, f in enumerate(model_files): for i, f in enumerate(model_files):
f.write_text(f"data{i}") f.write_text(f"data{i}")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add file that should be ignored - hash should not change # Add file that should be ignored - hash should not change
file = tmp_path / "test.icecream" file = tmp_path / "test.icecream"
file.write_text("data") file.write_text("data")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add file that should not be ignored - hash should change # Add file that should not be ignored - hash should change
file = tmp_path / "test.bin" file = tmp_path / "test.bin"
file.write_text("more data") file.write_text("more data")
model_files.add(file) model_files.add(file)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
def test_model_hash_uses_custom_filter(tmp_path: Path):
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
def file_filter(file_path: str) -> bool:
return file_path.endswith(".pickme")
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}