From ec8ed530a756d56bb2449b696c712ae1cebf8feb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:06:21 +1100 Subject: [PATCH] feat(mm): modularize ModelHash to facilitate testing --- invokeai/backend/model_manager/hash.py | 39 +++++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index 3144123761..a7ac014194 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -50,6 +50,7 @@ class ModelHash: :param model_location: Path to the model :param algorithm: Hashing algorithm to use """ + model_location = Path(model_location) if model_location.is_file(): return cls._hash_file(model_location, algorithm) @@ -59,7 +60,7 @@ class ModelHash: raise OSError(f"Not a valid file or directory: {model_location}") @classmethod - def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str: + def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str: """ Compute the hash for a single file and return its hexdigest. @@ -77,44 +78,48 @@ class ModelHash: raise ValueError(f"Algorithm {algorithm} not available") @classmethod - def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str: + def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str: """ Compute the hash for all files in a directory and return a hexdigest. :param model_location: Path to the model directory :param algorithm: Hashing algorithm to use """ - components: list[str] = [] - - for root, _dirs, files in os.walk(model_location): - for file in files: - # only tally tensor files because diffusers config files change slightly - # depending on how the model was downloaded/converted. - if file.endswith(MODEL_FILE_EXTENSIONS): - components.append((Path(root, file).as_posix())) + model_component_paths = cls._get_file_paths(model_location) component_hashes: list[str] = [] - for component in sorted(components): + for component in sorted(model_component_paths): component_hashes.append(cls._hash_file(component, algorithm)) # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm # for the composite hash composite_hasher = blake3() - for h in components: + for h in component_hashes: composite_hasher.update(h.encode("utf-8")) return composite_hasher.hexdigest() + @classmethod + def _get_file_paths(cls, dir: Path) -> list[Path]: + """Return a list of all model files in the directory.""" + files: list[Path] = [] + for root, _dirs, _files in os.walk(dir): + for file in _files: + if file.endswith(MODEL_FILE_EXTENSIONS): + files.append(Path(root, file)) + return files + @staticmethod - def _blake3(file_path: Union[str, Path]) -> str: + def _blake3(file_path: Path) -> str: """Hashes a file using BLAKE3""" file_hasher = blake3(max_threads=blake3.AUTO) file_hasher.update_mmap(file_path) return file_hasher.hexdigest() @staticmethod - def _sha1_fast(file_path: Union[str, Path]) -> str: - """Hashes a file using SHA1, but with a block size of 2**16. The result is not a standard SHA1 hash due to the - # padding introduced by the block size. The algorithm is, however, very fast.""" + def _sha1_fast(file_path: Path) -> str: + """Hashes a file using SHA1, but with a block size of 2**16. + The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size. + The algorithm is, however, very fast.""" BLOCK_SIZE = 2**16 file_hash = hashlib.sha1() with open(file_path, "rb") as f: @@ -123,7 +128,7 @@ class ModelHash: return file_hash.hexdigest() @staticmethod - def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str: + def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str: """Hashes a file using a hashlib algorithm""" file_hasher = hashlib.new(algorithm) with open(file_path, "rb") as f: