mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add algorithm prefix to hashes
For example:
- md5:a0cd925fc063f98dbf029eee315060c3
- sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c
- blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0
This commit is contained in:
parent
a4be935458
commit
9fcd67b5c0
@ -33,7 +33,7 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
|||||||
|
|
||||||
class ModelHash:
|
class ModelHash:
|
||||||
"""
|
"""
|
||||||
Creates a hash of a model using a specified algorithm.
|
Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||||
@ -53,15 +53,17 @@ class ModelHash:
|
|||||||
Usage:
|
Usage:
|
||||||
```py
|
```py
|
||||||
# BLAKE3 hash
|
# BLAKE3 hash
|
||||||
ModelHash().hash("path/to/some/model.safetensors")
|
ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
|
||||||
# MD5
|
# MD5
|
||||||
ModelHash("md5").hash("path/to/model/dir/")
|
ModelHash("md5").hash("path/to/model/dir/") # "md5:a0cd925fc063f98dbf029eee315060c3"
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
|
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# The extra type annotation here is necessary for pyright to understand that algorithm is a Literal type
|
||||||
|
self.algorithm: HASHING_ALGORITHMS = algorithm
|
||||||
if algorithm == "blake3":
|
if algorithm == "blake3":
|
||||||
self._hash_file = self._blake3
|
self._hash_file = self._blake3
|
||||||
elif algorithm == "blake3_single":
|
elif algorithm == "blake3_single":
|
||||||
@ -90,10 +92,12 @@ class ModelHash:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
|
||||||
|
prefix = self._get_prefix(self.algorithm)
|
||||||
if model_path.is_file():
|
if model_path.is_file():
|
||||||
return self._hash_file(model_path)
|
return prefix + self._hash_file(model_path)
|
||||||
elif model_path.is_dir():
|
elif model_path.is_dir():
|
||||||
return self._hash_dir(model_path)
|
return prefix + self._hash_dir(model_path)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||||
|
|
||||||
@ -117,6 +121,7 @@ class ModelHash:
|
|||||||
composite_hasher = blake3()
|
composite_hasher = blake3()
|
||||||
for h in component_hashes:
|
for h in component_hashes:
|
||||||
composite_hasher.update(h.encode("utf-8"))
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
|
|
||||||
return composite_hasher.hexdigest()
|
return composite_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -207,3 +212,9 @@ class ModelHash:
|
|||||||
True if the file matches the given extensions, otherwise False
|
True if the file matches the given extensions, otherwise False
|
||||||
"""
|
"""
|
||||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
|
||||||
|
"""Return the prefix for the given algorithm, e.g. \"blake3:\" or \"md5:\"."""
|
||||||
|
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
|
||||||
|
return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"
|
||||||
|
@ -9,14 +9,15 @@ from blake3 import blake3
|
|||||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash
|
||||||
|
|
||||||
test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
|
test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
|
||||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
("md5", "md5:a0cd925fc063f98dbf029eee315060c3"),
|
||||||
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
("sha1", "sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
||||||
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
("sha256", "sha256:6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
||||||
(
|
(
|
||||||
"sha512",
|
"sha512",
|
||||||
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
|
"sha512:c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
|
||||||
),
|
),
|
||||||
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
("blake3", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
||||||
|
("blake3_single", "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -24,11 +25,11 @@ test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
|
|||||||
def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
|
def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
|
||||||
file = Path(tmp_path / "test")
|
file = Path(tmp_path / "test")
|
||||||
file.write_text("model data")
|
file.write_text("model data")
|
||||||
md5 = ModelHash(algorithm).hash(file)
|
hash_ = ModelHash(algorithm).hash(file)
|
||||||
assert md5 == expected_hash
|
assert hash_ == expected_hash
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3", "blake3_single"])
|
||||||
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
|
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
|
||||||
model_hash = ModelHash(algorithm)
|
model_hash = ModelHash(algorithm)
|
||||||
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
||||||
@ -36,15 +37,33 @@ def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
|
|||||||
for f in files:
|
for f in files:
|
||||||
f.write_text("data")
|
f.write_text("data")
|
||||||
|
|
||||||
md5 = model_hash.hash(tmp_path)
|
hash_ = model_hash.hash(tmp_path)
|
||||||
|
|
||||||
# Manual implementation of composite hash - always uses BLAKE3
|
# Manual implementation of composite hash - always uses BLAKE3
|
||||||
|
component_hashes: list[str] = []
|
||||||
|
for f in sorted(ModelHash._get_file_paths(tmp_path, ModelHash._default_file_filter)):
|
||||||
|
component_hashes.append(model_hash._hash_file(f))
|
||||||
|
|
||||||
composite_hasher = blake3()
|
composite_hasher = blake3()
|
||||||
for f in files:
|
for h in component_hashes:
|
||||||
h = model_hash.hash(f)
|
|
||||||
composite_hasher.update(h.encode("utf-8"))
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
|
|
||||||
assert md5 == composite_hasher.hexdigest()
|
assert hash_ == ModelHash._get_prefix(algorithm) + composite_hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"algorithm,expected_prefix",
|
||||||
|
[
|
||||||
|
("md5", "md5:"),
|
||||||
|
("sha1", "sha1:"),
|
||||||
|
("sha256", "sha256:"),
|
||||||
|
("sha512", "sha512:"),
|
||||||
|
("blake3", "blake3:"),
|
||||||
|
("blake3_single", "blake3:"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_model_hash_gets_prefix(algorithm: HASHING_ALGORITHMS, expected_prefix: str):
|
||||||
|
assert ModelHash._get_prefix(algorithm) == expected_prefix
|
||||||
|
|
||||||
|
|
||||||
def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):
|
def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user