mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add algorithm prefix to hashes
For example:
- md5:a0cd925fc063f98dbf029eee315060c3
- sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c
- blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0
This commit is contained in:
@ -33,7 +33,7 @@ MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
class ModelHash:
|
||||
"""
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
|
||||
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
@ -53,15 +53,17 @@ class ModelHash:
|
||||
Usage:
|
||||
```py
|
||||
# BLAKE3 hash
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
|
||||
# MD5
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
ModelHash("md5").hash("path/to/model/dir/") # "md5:a0cd925fc063f98dbf029eee315060c3"
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
|
||||
) -> None:
|
||||
# The extra type annotation here is necessary for pyright to understand that algorithm is a Literal type
|
||||
self.algorithm: HASHING_ALGORITHMS = algorithm
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm == "blake3_single":
|
||||
@ -90,10 +92,12 @@ class ModelHash:
|
||||
"""
|
||||
|
||||
model_path = Path(model_path)
|
||||
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
|
||||
prefix = self._get_prefix(self.algorithm)
|
||||
if model_path.is_file():
|
||||
return self._hash_file(model_path)
|
||||
return prefix + self._hash_file(model_path)
|
||||
elif model_path.is_dir():
|
||||
return self._hash_dir(model_path)
|
||||
return prefix + self._hash_dir(model_path)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||
|
||||
@ -117,6 +121,7 @@ class ModelHash:
|
||||
composite_hasher = blake3()
|
||||
for h in component_hashes:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
@ -207,3 +212,9 @@ class ModelHash:
|
||||
True if the file matches the given extensions, otherwise False
|
||||
"""
|
||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||
|
||||
@staticmethod
def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
    """Build the "<name>:" tag that is prepended to a hash string.

    Args:
        algorithm: Name of the hashing algorithm.

    Returns:
        The algorithm name followed by a colon, e.g. "md5:". The
        "blake3_single" algorithm maps to "blake3:".
    """
    # blake3_single is only the single-threaded variant of blake3, so it
    # shares the plain "blake3:" prefix rather than getting its own.
    if algorithm == "blake3_single":
        return "blake3:"
    return f"{algorithm}:"
|
||||
|
Reference in New Issue
Block a user