diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index 1139a1dacf..656b591f4a 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -7,10 +7,11 @@ from invokeai.backend.model_managre.model_hash import FastModelHash >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') 'a8e693a126ea5b831c96064dc569956f' """ + import hashlib import os from pathlib import Path -from typing import Callable, Literal, Union +from typing import Callable, Literal, Optional, Union from blake3 import blake3 @@ -19,7 +20,6 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") ALGORITHM = Literal[ "md5", "sha1", - "sha1_fast", "sha224", "sha256", "sha384", @@ -40,7 +40,9 @@ class ModelHash: """ Creates a hash of a model using a specified algorithm. - :param algorithm: Hashing algorithm to use. Defaults to BLAKE3. + Args: + algorithm: Hashing algorithm to use. Defaults to BLAKE3. + file_filter: A function that takes a file name and returns True if the file should be included in the hash. If the model is a single file, it is hashed directly using the provided algorithm. @@ -51,45 +53,57 @@ class ModelHash: The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring that directory hashes are never weaker than the file hashes. - Usage - - ```py - ModelHash().hash("path/to/some/model.safetensors") - ModelHash("md5").hash("path/to/model/dir/") - ``` + Usage: + ```py + # BLAKE3 hash + ModelHash().hash("path/to/some/model.safetensors") + # MD5 + ModelHash("md5").hash("path/to/model/dir/") + ``` """ - def __init__(self, algorithm: ALGORITHM = "blake3") -> None: + def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None: if algorithm == "blake3": self._hash_file = self._blake3 - elif algorithm == "sha1_fast": - self._hash_file = self._sha1_fast elif algorithm in hashlib.algorithms_available: self._hash_file = self._get_hashlib(algorithm) else: raise ValueError(f"Algorithm {algorithm} not available") - def hash(self, model_location: Union[str, Path]) -> str: - """ - Return hexdigest string for model located at model_location. + self._file_filter = file_filter or self._default_file_filter - If model_location is a directory, the hash is computed by hashing the hashes of all model files in the + def hash(self, model_path: Union[str, Path]) -> str: + """ + Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation. + + If model_path is a directory, the hash is computed by hashing the hashes of all model files in the directory. The final composite hash is always computed using BLAKE3. - :param model_location: Path to the model + Args: + model_path: Path to the model + + Returns: + str: Hexdigest of the hash of the model """ - model_location = Path(model_location) - if model_location.is_file(): - return self._hash_file(model_location) - elif model_location.is_dir(): - return self._hash_dir(model_location) + model_path = Path(model_path) + if model_path.is_file(): + return self._hash_file(model_path) + elif model_path.is_dir(): + return self._hash_dir(model_path) else: - raise OSError(f"Not a valid file or directory: {model_location}") + raise OSError(f"Not a valid file or directory: {model_path}") - def _hash_dir(self, model_location: Path) -> str: - """Compute the hash for all files in a directory and return a hexdigest.""" - model_component_paths = self._get_file_paths(model_location) + def _hash_dir(self, dir: Path) -> str: + """Compute the hash for all files in a directory and return a hexdigest. + + Args: + dir: Path to the directory + + Returns: + str: Hexdigest of the hash of the directory + """ + model_component_paths = self._get_file_paths(dir, self._file_filter) component_hashes: list[str] = [] for component in sorted(model_component_paths): @@ -102,43 +116,70 @@ class ModelHash: composite_hasher.update(h.encode("utf-8")) return composite_hasher.hexdigest() - @classmethod - def _get_file_paths(cls, dir: Path) -> list[Path]: - """Return a list of all model files in the directory.""" + @staticmethod + def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]: + """Return a list of all model files in the directory. + + Args: + model_path: Path to the model + file_filter: Function that takes a file name and returns True if the file should be included in the list. + + Returns: + List of all model files in the directory + """ + files: list[Path] = [] - for root, _dirs, _files in os.walk(dir): + for root, _dirs, _files in os.walk(model_path): for file in _files: - if file.endswith(MODEL_FILE_EXTENSIONS): + if file_filter(file): files.append(Path(root, file)) return files @staticmethod def _blake3(file_path: Path) -> str: - """Hashes a file using BLAKE3""" + """Hashes a file using BLAKE3 + + Args: + file_path: Path to the file to hash + + Returns: + Hexdigest of the hash of the file + """ file_hasher = blake3(max_threads=blake3.AUTO) file_hasher.update_mmap(file_path) return file_hasher.hexdigest() @staticmethod - def _sha1_fast(file_path: Path) -> str: - """Hashes a file using SHA1, but with a block size of 2**16. - The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size. - The algorithm is, however, very fast.""" - BLOCK_SIZE = 2**16 - file_hash = hashlib.sha1() - with open(file_path, "rb") as f: - data = f.read(BLOCK_SIZE) - file_hash.update(data) - return file_hash.hexdigest() + def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: + """Factory function that returns a function to hash a file with the given algorithm. + + Args: + algorithm: Hashing algorithm to use + + Returns: + A function that hashes a file using the given algorithm + """ + + def hashlib_hasher(file_path: Path) -> str: + """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory.""" + hasher = hashlib.new(algorithm) + buffer = bytearray(128 * 1024) + mv = memoryview(buffer) + with open(file_path, "rb", buffering=0) as f: + while n := f.readinto(mv): + hasher.update(mv[:n]) + return hasher.hexdigest() + + return hashlib_hasher @staticmethod - def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: - """Hashes a file using a hashlib algorithm""" + def _default_file_filter(file_path: str) -> bool: + """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth - def hasher(file_path: Path) -> str: - file_hasher = hashlib.new(algorithm) - with open(file_path, "rb") as f: - file_hasher.update(f.read()) - return file_hasher.hexdigest() + Args: + file_path: Path to the file - return hasher + Returns: + True if the file matches the given extensions, otherwise False + """ + return file_path.endswith(MODEL_FILE_EXTENSIONS) diff --git a/tests/test_model_hash.py b/tests/test_model_hash.py index 763aa4fc63..641a150034 100644 --- a/tests/test_model_hash.py +++ b/tests/test_model_hash.py @@ -6,7 +6,7 @@ from typing import Iterable import pytest from blake3 import blake3 -from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash +from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash test_cases: list[tuple[ALGORITHM, str]] = [ ("md5", "a0cd925fc063f98dbf029eee315060c3"), @@ -57,24 +57,40 @@ def paths_to_str_set(paths: Iterable[Path]) -> set[str]: def test_model_hash_filters_out_non_model_files(tmp_path: Path): - model_files = { - Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"]) - } + model_files = {Path(tmp_path, f"{i}{ext}") for i, ext in enumerate(MODEL_FILE_EXTENSIONS)} for i, f in enumerate(model_files): f.write_text(f"data{i}") - assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) # Add file that should be ignored - hash should not change file = tmp_path / "test.icecream" file.write_text("data") - assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) # Add file that should not be ignored - hash should change file = tmp_path / "test.bin" file.write_text("more data") model_files.add(file) - assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files) + assert paths_to_str_set(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)) == paths_to_str_set( + model_files + ) + + +def test_model_hash_uses_custom_filter(tmp_path: Path): + model_files = {Path(tmp_path, f"file{ext}") for ext in [".pickme", ".ignoreme"]} + + for i, f in enumerate(model_files): + f.write_text(f"data{i}") + + def file_filter(file_path: str) -> bool: + return file_path.endswith(".pickme") + + assert {p.name for p in ModelHash._get_file_paths(tmp_path, file_filter)} == {"file.pickme"}