mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): improved model hash class
- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib) - Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block) - Add support for custom file filters - Update docstrings - Update tests
This commit is contained in:
parent
93aed57e81
commit
af24013bb8
@ -7,10 +7,11 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
|||||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||||
'a8e693a126ea5b831c96064dc569956f'
|
'a8e693a126ea5b831c96064dc569956f'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Literal, Union
|
from typing import Callable, Literal, Optional, Union
|
||||||
|
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
@ -19,7 +20,6 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
|||||||
ALGORITHM = Literal[
|
ALGORITHM = Literal[
|
||||||
"md5",
|
"md5",
|
||||||
"sha1",
|
"sha1",
|
||||||
"sha1_fast",
|
|
||||||
"sha224",
|
"sha224",
|
||||||
"sha256",
|
"sha256",
|
||||||
"sha384",
|
"sha384",
|
||||||
@ -40,7 +40,9 @@ class ModelHash:
|
|||||||
"""
|
"""
|
||||||
Creates a hash of a model using a specified algorithm.
|
Creates a hash of a model using a specified algorithm.
|
||||||
|
|
||||||
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
Args:
|
||||||
|
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||||
|
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
|
||||||
|
|
||||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||||
|
|
||||||
@ -51,45 +53,57 @@ class ModelHash:
|
|||||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||||
that directory hashes are never weaker than the file hashes.
|
that directory hashes are never weaker than the file hashes.
|
||||||
|
|
||||||
Usage
|
Usage:
|
||||||
|
```py
|
||||||
```py
|
# BLAKE3 hash
|
||||||
ModelHash().hash("path/to/some/model.safetensors")
|
ModelHash().hash("path/to/some/model.safetensors")
|
||||||
ModelHash("md5").hash("path/to/model/dir/")
|
# MD5
|
||||||
```
|
ModelHash("md5").hash("path/to/model/dir/")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
|
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
||||||
if algorithm == "blake3":
|
if algorithm == "blake3":
|
||||||
self._hash_file = self._blake3
|
self._hash_file = self._blake3
|
||||||
elif algorithm == "sha1_fast":
|
|
||||||
self._hash_file = self._sha1_fast
|
|
||||||
elif algorithm in hashlib.algorithms_available:
|
elif algorithm in hashlib.algorithms_available:
|
||||||
self._hash_file = self._get_hashlib(algorithm)
|
self._hash_file = self._get_hashlib(algorithm)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Algorithm {algorithm} not available")
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
def hash(self, model_location: Union[str, Path]) -> str:
|
self._file_filter = file_filter or self._default_file_filter
|
||||||
"""
|
|
||||||
Return hexdigest string for model located at model_location.
|
|
||||||
|
|
||||||
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
|
def hash(self, model_path: Union[str, Path]) -> str:
|
||||||
|
"""
|
||||||
|
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
|
||||||
|
|
||||||
|
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||||
directory. The final composite hash is always computed using BLAKE3.
|
directory. The final composite hash is always computed using BLAKE3.
|
||||||
|
|
||||||
:param model_location: Path to the model
|
Args:
|
||||||
|
model_path: Path to the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Hexdigest of the hash of the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_location = Path(model_location)
|
model_path = Path(model_path)
|
||||||
if model_location.is_file():
|
if model_path.is_file():
|
||||||
return self._hash_file(model_location)
|
return self._hash_file(model_path)
|
||||||
elif model_location.is_dir():
|
elif model_path.is_dir():
|
||||||
return self._hash_dir(model_location)
|
return self._hash_dir(model_path)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||||
|
|
||||||
def _hash_dir(self, model_location: Path) -> str:
|
def _hash_dir(self, dir: Path) -> str:
|
||||||
"""Compute the hash for all files in a directory and return a hexdigest."""
|
"""Compute the hash for all files in a directory and return a hexdigest.
|
||||||
model_component_paths = self._get_file_paths(model_location)
|
|
||||||
|
Args:
|
||||||
|
dir: Path to the directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Hexdigest of the hash of the directory
|
||||||
|
"""
|
||||||
|
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||||
|
|
||||||
component_hashes: list[str] = []
|
component_hashes: list[str] = []
|
||||||
for component in sorted(model_component_paths):
|
for component in sorted(model_component_paths):
|
||||||
@ -102,43 +116,70 @@ class ModelHash:
|
|||||||
composite_hasher.update(h.encode("utf-8"))
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
return composite_hasher.hexdigest()
|
return composite_hasher.hexdigest()
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def _get_file_paths(cls, dir: Path) -> list[Path]:
|
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
|
||||||
"""Return a list of all model files in the directory."""
|
"""Return a list of all model files in the directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model
|
||||||
|
file_filter: Function that takes a file name and returns True if the file should be included in the list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all model files in the directory
|
||||||
|
"""
|
||||||
|
|
||||||
files: list[Path] = []
|
files: list[Path] = []
|
||||||
for root, _dirs, _files in os.walk(dir):
|
for root, _dirs, _files in os.walk(model_path):
|
||||||
for file in _files:
|
for file in _files:
|
||||||
if file.endswith(MODEL_FILE_EXTENSIONS):
|
if file_filter(file):
|
||||||
files.append(Path(root, file))
|
files.append(Path(root, file))
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _blake3(file_path: Path) -> str:
|
def _blake3(file_path: Path) -> str:
|
||||||
"""Hashes a file using BLAKE3"""
|
"""Hashes a file using BLAKE3
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hexdigest of the hash of the file
|
||||||
|
"""
|
||||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||||
file_hasher.update_mmap(file_path)
|
file_hasher.update_mmap(file_path)
|
||||||
return file_hasher.hexdigest()
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sha1_fast(file_path: Path) -> str:
|
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||||
"""Hashes a file using SHA1, but with a block size of 2**16.
|
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||||
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
|
|
||||||
The algorithm is, however, very fast."""
|
Args:
|
||||||
BLOCK_SIZE = 2**16
|
algorithm: Hashing algorithm to use
|
||||||
file_hash = hashlib.sha1()
|
|
||||||
with open(file_path, "rb") as f:
|
Returns:
|
||||||
data = f.read(BLOCK_SIZE)
|
A function that hashes a file using the given algorithm
|
||||||
file_hash.update(data)
|
"""
|
||||||
return file_hash.hexdigest()
|
|
||||||
|
def hashlib_hasher(file_path: Path) -> str:
|
||||||
|
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
|
||||||
|
hasher = hashlib.new(algorithm)
|
||||||
|
buffer = bytearray(128 * 1024)
|
||||||
|
mv = memoryview(buffer)
|
||||||
|
with open(file_path, "rb", buffering=0) as f:
|
||||||
|
while n := f.readinto(mv):
|
||||||
|
hasher.update(mv[:n])
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
return hashlib_hasher
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
def _default_file_filter(file_path: str) -> bool:
|
||||||
"""Hashes a file using a hashlib algorithm"""
|
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
||||||
|
|
||||||
def hasher(file_path: Path) -> str:
|
Args:
|
||||||
file_hasher = hashlib.new(algorithm)
|
file_path: Path to the file
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
file_hasher.update(f.read())
|
|
||||||
return file_hasher.hexdigest()
|
|
||||||
|
|
||||||
return hasher
|
Returns:
|
||||||
|
True if the file matches the given extensions, otherwise False
|
||||||
|
"""
|
||||||
|
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Iterable
|
|||||||
import pytest
|
import pytest
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
|
||||||
|
|
||||||
test_cases: list[tuple[ALGORITHM, str]] = [
|
test_cases: list[tuple[ALGORITHM, str]] = [
|
||||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||||
@ -57,24 +57,40 @@ def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
|
|||||||
|
|
||||||
|
|
||||||
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
||||||
model_files = {
|
model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)}
|
||||||
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, f in enumerate(model_files):
|
for i, f in enumerate(model_files):
|
||||||
f.write_text(f"data{i}")
|
f.write_text(f"data{i}")
|
||||||
|
|
||||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||||
|
model_files
|
||||||
|
)
|
||||||
|
|
||||||
# Add file that should be ignored - hash should not change
|
# Add file that should be ignored - hash should not change
|
||||||
file = tmp_path / "test.icecream"
|
file = tmp_path / "test.icecream"
|
||||||
file.write_text("data")
|
file.write_text("data")
|
||||||
|
|
||||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||||
|
model_files
|
||||||
|
)
|
||||||
|
|
||||||
# Add file that should not be ignored - hash should change
|
# Add file that should not be ignored - hash should change
|
||||||
file = tmp_path / "test.bin"
|
file = tmp_path / "test.bin"
|
||||||
file.write_text("more data")
|
file.write_text("more data")
|
||||||
model_files.add(file)
|
model_files.add(file)
|
||||||
|
|
||||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set(
|
||||||
|
model_files
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_hash_uses_custom_filter(tmp_path: Path):
|
||||||
|
model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]}
|
||||||
|
|
||||||
|
for i, f in enumerate(model_files):
|
||||||
|
f.write_text(f"data{i}")
|
||||||
|
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
return file_path.endswith(".pickme")
|
||||||
|
|
||||||
|
assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}
|
||||||
|
Loading…
Reference in New Issue
Block a user