feat(mm): improved model hash class

- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib)
- Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block)
- Add support for custom file filters
- Update docstrings
- Update tests
This commit is contained in:
psychedelicious 2024-03-03 14:14:15 +11:00 committed by Ryan Dick
parent 93aed57e81
commit af24013bb8
2 changed files with 114 additions and 57 deletions

View File

@ -7,10 +7,11 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Union
from typing import Callable, Literal, Optional, Union
from blake3 import blake3
@ -19,7 +20,6 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha1_fast",
"sha224",
"sha256",
"sha384",
@ -40,7 +40,9 @@ class ModelHash:
"""
Creates a hash of a model using a specified algorithm.
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
If the model is a single file, it is hashed directly using the provided algorithm.
@ -51,45 +53,57 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
self._file_filter = file_filter or self._default_file_filter
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
def hash(self, model_path: Union[str, Path]) -> str:
"""
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
"""
model_location = Path(model_location)
if model_location.is_file():
return self._hash_file(model_location)
elif model_location.is_dir():
return self._hash_dir(model_location)
model_path = Path(model_path)
if model_path.is_file():
return self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
raise OSError(f"Not a valid file or directory: {model_path}")
def _hash_dir(self, model_location: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest."""
model_component_paths = self._get_file_paths(model_location)
def _hash_dir(self, dir: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest.
Args:
dir: Path to the directory
Returns:
str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
@ -102,43 +116,70 @@ class ModelHash:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@classmethod
def _get_file_paths(cls, dir: Path) -> list[Path]:
"""Return a list of all model files in the directory."""
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(dir):
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file.endswith(MODEL_FILE_EXTENSIONS):
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3"""
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _sha1_fast(file_path: Path) -> str:
"""Hashes a file using SHA1, but with a block size of 2**16.
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
The algorithm is, however, very fast."""
BLOCK_SIZE = 2**16
file_hash = hashlib.sha1()
with open(file_path, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using a hashlib algorithm"""
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
def hasher(file_path: Path) -> str:
file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()
Args:
file_path: Path to the file
return hasher
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)

View File

@ -6,7 +6,7 @@ from typing import Iterable
import pytest
from blake3 import blake3
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
test_cases: list[tuple[ALGORITHM, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"),
@ -57,24 +57,40 @@ def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
model_files = {
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
}
model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add file that should be ignored - hash should not change
file = tmp_path / "test.icecream"
file.write_text("data")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
# Add file that should not be ignored - hash should change
file = tmp_path / "test.bin"
file.write_text("more data")
model_files.add(file)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
model_files
)
def test_model_hash_uses_custom_filter(tmp_path: Path):
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
def file_filter(file_path: str) -> bool:
return file_path.endswith(".pickme")
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}