feat(mm): modularize ModelHash to facilitate testing

psychedelicious 2024-02-28 13:06:21 +11:00
parent 982076d7d7
commit ec8ed530a7


@@ -50,6 +50,7 @@ class ModelHash:
         :param model_location: Path to the model
         :param algorithm: Hashing algorithm to use
         """
         model_location = Path(model_location)
         if model_location.is_file():
             return cls._hash_file(model_location, algorithm)
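For orientation, a minimal usage sketch of the public entry point this hunk touches; the model paths and the "blake3" algorithm value are illustrative assumptions, not taken from the diff:

    from pathlib import Path

    # Hypothetical call sites: ModelHash.hash() accepts a file or a directory
    # and dispatches to _hash_file() or _hash_dir() accordingly.
    file_digest = ModelHash.hash("/models/some_model.safetensors", algorithm="blake3")
    dir_digest = ModelHash.hash(Path("/models/some-diffusers-model"), algorithm="blake3")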
@@ -59,7 +60,7 @@ class ModelHash:
             raise OSError(f"Not a valid file or directory: {model_location}")
 
     @classmethod
-    def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
+    def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
         """
         Compute the hash for a single file and return its hexdigest.
@@ -77,44 +78,48 @@ class ModelHash:
             raise ValueError(f"Algorithm {algorithm} not available")
 
     @classmethod
-    def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
+    def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
         """
         Compute the hash for all files in a directory and return a hexdigest.
 
         :param model_location: Path to the model directory
         :param algorithm: Hashing algorithm to use
         """
-        components: list[str] = []
-        for root, _dirs, files in os.walk(model_location):
-            for file in files:
-                # only tally tensor files because diffusers config files change slightly
-                # depending on how the model was downloaded/converted.
-                if file.endswith(MODEL_FILE_EXTENSIONS):
-                    components.append((Path(root, file).as_posix()))
+        model_component_paths = cls._get_file_paths(model_location)
 
         component_hashes: list[str] = []
-        for component in sorted(components):
+        for component in sorted(model_component_paths):
             component_hashes.append(cls._hash_file(component, algorithm))
 
         # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
         # for the composite hash
         composite_hasher = blake3()
-        for h in components:
+        for h in component_hashes:
             composite_hasher.update(h.encode("utf-8"))
         return composite_hasher.hexdigest()
 
+    @classmethod
+    def _get_file_paths(cls, dir: Path) -> list[Path]:
+        """Return a list of all model files in the directory."""
+        files: list[Path] = []
+        for root, _dirs, _files in os.walk(dir):
+            for file in _files:
+                if file.endswith(MODEL_FILE_EXTENSIONS):
+                    files.append(Path(root, file))
+        return files
+
     @staticmethod
-    def _blake3(file_path: Union[str, Path]) -> str:
+    def _blake3(file_path: Path) -> str:
         """Hashes a file using BLAKE3"""
         file_hasher = blake3(max_threads=blake3.AUTO)
         file_hasher.update_mmap(file_path)
         return file_hasher.hexdigest()
 
     @staticmethod
-    def _sha1_fast(file_path: Union[str, Path]) -> str:
-        """Hashes a file using SHA1, but with a block size of 2**16. The result is not a standard SHA1 hash due to the
-        # padding introduced by the block size. The algorithm is, however, very fast."""
+    def _sha1_fast(file_path: Path) -> str:
+        """Hashes a file using SHA1, but with a block size of 2**16.
+        The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
+        The algorithm is, however, very fast."""
         BLOCK_SIZE = 2**16
         file_hash = hashlib.sha1()
         with open(file_path, "rb") as f:
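The new _get_file_paths helper is what makes the directory-hashing path straightforward to unit test, per the commit title. A minimal pytest sketch of the kind of test it enables, assuming ".safetensors" appears in MODEL_FILE_EXTENSIONS and using pytest's tmp_path fixture (both assumptions, not part of this diff):

    def test_get_file_paths_only_returns_model_files(tmp_path):
        # Directory with one model file and one config file that should be ignored.
        (tmp_path / "unet.safetensors").write_bytes(b"fake weights")
        (tmp_path / "model_index.json").write_text("{}")

        found = ModelHash._get_file_paths(tmp_path)

        assert [p.name for p in found] == ["unet.safetensors"]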
@@ -123,7 +128,7 @@ class ModelHash:
         return file_hash.hexdigest()
 
     @staticmethod
-    def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str:
+    def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str:
         """Hashes a file using a hashlib algorithm"""
         file_hasher = hashlib.new(algorithm)
         with open(file_path, "rb") as f:
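The final hunk cuts off inside _hashlib's read loop. For reference, a standalone sketch of the same chunked-read hashing pattern (an illustration under assumed names, not the project's exact code):

    import hashlib
    from pathlib import Path

    def hashlib_digest(file_path: Path, algorithm: str = "sha256", block_size: int = 2**16) -> str:
        # Read the file in fixed-size blocks so large model files are never
        # loaded into memory all at once.
        hasher = hashlib.new(algorithm)
        with open(file_path, "rb") as f:
            while chunk := f.read(block_size):
                hasher.update(chunk)
        return hasher.hexdigest()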