diff --git a/invokeai/backend/model_hash/model_hash.py b/invokeai/backend/model_hash/model_hash.py
index b1cd93be7b..36f70ca3b9 100644
--- a/invokeai/backend/model_hash/model_hash.py
+++ b/invokeai/backend/model_hash/model_hash.py
@@ -33,7 +33,7 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
 
 class ModelHash:
     """
-    Creates a hash of a model using a specified algorithm.
+    Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
 
     Args:
         algorithm: Hashing algorithm to use. Defaults to BLAKE3.
@@ -53,15 +53,17 @@ class ModelHash:
     Usage:
         ```py
         # BLAKE3 hash
-        ModelHash().hash("path/to/some/model.safetensors")
+        ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
         # MD5
-        ModelHash("md5").hash("path/to/model/dir/")
+        ModelHash("md5").hash("path/to/model/dir/") # "md5:" + BLAKE3 digest of the per-file MD5 hashes
         ```
     """
 
     def __init__(
         self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
     ) -> None:
+        # The extra type annotation here is necessary for pyright to understand that algorithm is a Literal type
+        self.algorithm: HASHING_ALGORITHMS = algorithm
         if algorithm == "blake3":
             self._hash_file = self._blake3
         elif algorithm == "blake3_single":
@@ -90,10 +92,12 @@ class ModelHash:
         """
 
         model_path = Path(model_path)
+        # blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
+        prefix = self._get_prefix(self.algorithm)
         if model_path.is_file():
-            return self._hash_file(model_path)
+            return prefix + self._hash_file(model_path)
         elif model_path.is_dir():
-            return self._hash_dir(model_path)
+            return prefix + self._hash_dir(model_path)
         else:
             raise OSError(f"Not a valid file or directory: {model_path}")
 
@@ -117,6 +121,7 @@ class ModelHash:
         composite_hasher = blake3()
         for h in component_hashes:
             composite_hasher.update(h.encode("utf-8"))
+
         return composite_hasher.hexdigest()
 
     @staticmethod
@@ -207,3 +212,9 @@ class ModelHash:
             True if the file matches the given extensions, otherwise False
         """
         return file_path.endswith(MODEL_FILE_EXTENSIONS)
+
+    @staticmethod
+    def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
+        """Return the prefix for the given algorithm, e.g. "blake3:" or "md5:"."""
+        # blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
+        return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"
diff --git a/tests/test_model_hash.py b/tests/test_model_hash.py
index e7150633e5..5a3feb4462 100644
--- a/tests/test_model_hash.py
+++ b/tests/test_model_hash.py
@@ -9,14 +9,15 @@ from blake3 import blake3
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash
 
 test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
-    ("md5", "a0cd925fc063f98dbf029eee315060c3"),
-    ("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
-    ("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
+    ("md5", "md5:a0cd925fc063f98dbf029eee315060c3"),
+    ("sha1", "sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c"),
+    ("sha256", "sha256:6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
     (
         "sha512",
-        "c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
+        "sha512:c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
     ),
-    ("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
+    ("blake3", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
+    ("blake3_single", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
 ]
 
 
@@ -24,11 +25,11 @@ test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
 def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
     file = Path(tmp_path / "test")
     file.write_text("model data")
-    md5 = ModelHash(algorithm).hash(file)
-    assert md5 == expected_hash
+    hash_ = ModelHash(algorithm).hash(file)
+    assert hash_ == expected_hash
 
 
-@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
+@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3", "blake3_single"])
 def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
     model_hash = ModelHash(algorithm)
     files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
@@ -36,15 +37,33 @@ def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
     for f in files:
         f.write_text("data")
 
-    md5 = model_hash.hash(tmp_path)
+    hash_ = model_hash.hash(tmp_path)
 
     # Manual implementation of composite hash - always uses BLAKE3
+    component_hashes: list[str] = []
+    for f in sorted(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)):
+        component_hashes.append(model_hash._hash_file(f))
+
     composite_hasher = blake3()
-    for f in files:
-        h = model_hash.hash(f)
+    for h in component_hashes:
         composite_hasher.update(h.encode("utf-8"))
 
-    assert md5 == composite_hasher.hexdigest()
+    assert hash_ == ModelHash._get_prefix(algorithm) + composite_hasher.hexdigest()
+
+
+@pytest.mark.parametrize(
+    "algorithm,expected_prefix",
+    [
+        ("md5", "md5:"),
+        ("sha1", "sha1:"),
+        ("sha256", "sha256:"),
+        ("sha512", "sha512:"),
+        ("blake3", "blake3:"),
+        ("blake3_single", "blake3:"),
+    ],
+)
+def test_model_hash_gets_prefix(algorithm: HASHING_ALGORITHMS, expected_prefix: str):
+    assert ModelHash._get_prefix(algorithm) == expected_prefix
 
 
 def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):