# InvokeAI: invokeai/backend/model_manager/hash.py

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Union
from blake3 import blake3
# File-name suffixes that mark a file as model weights when a whole directory is hashed.
MODEL_FILE_EXTENSIONS = (
    ".ckpt",
    ".safetensors",
    ".bin",
    ".pt",
    ".pth",
)

# Algorithm names accepted by ModelHash. "blake3" and "sha1_fast" are implemented by ModelHash
# itself; the remaining names are looked up in hashlib.
ALGORITHM = Literal[
    # Legacy / fast fingerprints
    "md5",
    "sha1",
    "sha1_fast",
    # SHA-2 family
    "sha224",
    "sha256",
    "sha384",
    "sha512",
    # BLAKE2 family
    "blake2b",
    "blake2s",
    # SHA-3 fixed-length digests
    "sha3_224",
    "sha3_256",
    "sha3_384",
    "sha3_512",
    # SHA-3 extendable-output functions
    "shake_128",
    "shake_256",
    # Default
    "blake3",
]
class ModelHash:
    """
    Creates a hash of a model using a specified algorithm.

    :param algorithm: Hashing algorithm to use. Defaults to BLAKE3.

    If the model is a single file, it is hashed directly using the provided algorithm.

    If the model is a directory, each model weights file in the directory (searched recursively) is hashed
    using the provided algorithm. Only files with the following extensions are hashed:
    .ckpt, .safetensors, .bin, .pt, .pth

    The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
    that directory hashes are never weaker than the file hashes.

    Algorithm notes:
    - "sha1_fast" hashes only the first 64 KiB of each file: a quick fingerprint, not a true SHA-1.
    - "shake_128" / "shake_256" are extendable-output functions; they produce 32- and 64-byte digests here.

    Usage
    ```py
    ModelHash().hash("path/to/some/model.safetensors")
    ModelHash("md5").hash("path/to/model/dir/")
    ```
    """

    def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
        """Select the per-file hashing function for ``algorithm``.

        :param algorithm: Name of the hashing algorithm. "blake3" and "sha1_fast" are handled here;
            any other name must be present in ``hashlib.algorithms_available``.
        :raises ValueError: If the algorithm is not available.
        """
        if algorithm == "blake3":
            self._hash_file = self._blake3
        elif algorithm == "sha1_fast":
            self._hash_file = self._sha1_fast
        elif algorithm in hashlib.algorithms_available:
            self._hash_file = self._get_hashlib(algorithm)
        else:
            raise ValueError(f"Algorithm {algorithm} not available")

    def hash(self, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
        directory. The final composite hash is always computed using BLAKE3.

        :param model_location: Path to the model
        :raises OSError: If model_location is neither a regular file nor a directory (e.g. it does not exist).
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return self._hash_file(model_location)
        elif model_location.is_dir():
            return self._hash_dir(model_location)
        else:
            raise OSError(f"Not a valid file or directory: {model_location}")

    def _hash_dir(self, model_location: Path) -> str:
        """Compute the hash for all model files in a directory and return a hexdigest.

        Component digests are fed to BLAKE3 in sorted-path order, so the result does not depend on the
        order in which the filesystem lists entries. Only the component digests are hashed, not the file
        names. A directory containing no model files yields the BLAKE3 digest of empty input.
        """
        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        for component in sorted(self._get_file_paths(model_location)):
            composite_hasher.update(self._hash_file(component).encode("utf-8"))
        return composite_hasher.hexdigest()

    @classmethod
    def _get_file_paths(cls, dir: Path) -> list[Path]:
        """Return all model weight files under ``dir``, searched recursively.

        A file qualifies when its name ends with one of MODEL_FILE_EXTENSIONS (a case-sensitive match).
        The list follows os.walk order and is not sorted.
        """
        files: list[Path] = []
        for root, _dirs, _files in os.walk(dir):
            for file in _files:
                if file.endswith(MODEL_FILE_EXTENSIONS):
                    files.append(Path(root, file))
        return files

    @staticmethod
    def _blake3(file_path: Path) -> str:
        """Hash a file with BLAKE3, letting the library choose the thread count and memory-map the file."""
        file_hasher = blake3(max_threads=blake3.AUTO)
        file_hasher.update_mmap(file_path)
        return file_hasher.hexdigest()

    @staticmethod
    def _sha1_fast(file_path: Path) -> str:
        """Return the SHA-1 hex digest of only the first 64 KiB (2**16 bytes) of a file.

        This is NOT the SHA-1 of the whole file. Everything past the first 64 KiB is ignored, which is
        what makes it fast. It also means that two files sharing the same first 64 KiB hash identically.
        """
        # NOTE(review): presumably intentional (speed / stable legacy IDs); hashing the full file here
        # would change every previously stored "sha1_fast" hash.
        BLOCK_SIZE = 2**16
        file_hash = hashlib.sha1()
        with open(file_path, "rb") as f:
            file_hash.update(f.read(BLOCK_SIZE))
        return file_hash.hexdigest()

    @staticmethod
    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
        """Return a function that hashes a file with the named hashlib algorithm.

        The file is streamed in fixed-size chunks, so multi-gigabyte weight files are never loaded into
        memory at once. The digest is identical to hashing the whole content in one call.

        :param algorithm: A name accepted by ``hashlib.new``.
        :return: Callable taking a file path and returning its hex digest.
        """
        chunk_size = 2**20  # 1 MiB per read
        # SHAKE objects are extendable-output: their hexdigest() requires an explicit length and raises
        # TypeError without one. Use the output size of the matching SHA3 function (2x the security level).
        shake_digest_bytes = {"shake_128": 32, "shake_256": 64}
        digest_length = shake_digest_bytes.get(algorithm)

        def hasher(file_path: Path) -> str:
            file_hasher = hashlib.new(algorithm)
            with open(file_path, "rb") as f:
                while chunk := f.read(chunk_size):
                    file_hasher.update(chunk)
            if digest_length is None:
                return file_hasher.hexdigest()
            return file_hasher.hexdigest(digest_length)

        return hasher